From 260d5c72cf5c2b3559fd1149b02832cc54696dcc Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Fri, 3 Oct 2025 12:49:18 +0530 Subject: [PATCH 01/37] Draft one: modify model, added crud, route and fix deps, migration --- .../d209cddac1fa_refactor_api_key_table.py | 43 +++ backend/app/api/deps.py | 63 +++- backend/app/api/routes/api_keys.py | 121 ++------ backend/app/crud/__init__.py | 11 +- backend/app/crud/api_key.py | 280 ++++++++---------- backend/app/crud/onboarding.py | 9 +- backend/app/models/__init__.py | 6 +- backend/app/models/api_key.py | 30 +- backend/app/models/organization.py | 3 - backend/app/models/project.py | 3 - backend/app/models/user.py | 8 +- backend/app/seed_data/seed_data.py | 35 ++- 12 files changed, 314 insertions(+), 298 deletions(-) create mode 100644 backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py diff --git a/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py b/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py new file mode 100644 index 00000000..93c2db46 --- /dev/null +++ b/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py @@ -0,0 +1,43 @@ +"""Refactor API key table + +Revision ID: d209cddac1fa +Revises: c6fb6d0b5897 +Create Date: 2025-10-03 11:35:13.012517 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = 'd209cddac1fa' +down_revision = 'c6fb6d0b5897' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('apikey', sa.Column('key_prefix', sqlmodel.sql.sqltypes.AutoString(), nullable=False)) + op.add_column('apikey', sa.Column('key_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=False)) + op.add_column('apikey', sa.Column('last_used_at', sa.DateTime(), nullable=True)) + + op.add_column('apikey', sa.Column('new_id', sa.Uuid(), nullable=True)) + op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;") + op.execute("UPDATE apikey SET new_id = gen_random_uuid();") + + # Replace old PK + op.drop_constraint('apikey_pkey', 'apikey', type_='primary') + op.drop_column('apikey', 'id') + op.alter_column('apikey', 'new_id', new_column_name='id', nullable=False) + op.create_primary_key('apikey_pkey', 'apikey', ['id']) + + op.drop_index('ix_apikey_key', table_name='apikey') + op.create_index(op.f('ix_apikey_key_prefix'), 'apikey', ['key_prefix'], unique=True) + op.drop_column('apikey', 'key') + # ### end Alembic commands ### + + +def downgrade(): + pass diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index fd946e31..33c6bef4 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -14,10 +14,11 @@ from app.core.db import engine from app.utils import APIResponse from app.crud.organization import validate_organization -from app.crud.api_key import get_api_key_by_value +from app.crud.api_key import verify_api_key from app.models import ( TokenPayload, User, + UserContext, UserProjectOrg, UserOrganization, ProjectUser, @@ -48,7 +49,7 @@ def get_current_user( """Authenticate user via API Key first, fallback to JWT token. Returns only User.""" if api_key: - api_key_record = get_api_key_by_value(session, api_key) + api_key_record = verify_api_key(session, api_key) if not api_key_record: raise HTTPException(status_code=401, detail="Invalid API Key") @@ -94,7 +95,7 @@ def get_current_user_org( organization_id = None api_key = request.headers.get("X-API-KEY") if api_key: - api_key_record = get_api_key_by_value(session, api_key) + api_key_record = verify_api_key(session, api_key) if api_key_record: validate_organization(session, api_key_record.organization_id) organization_id = api_key_record.organization_id @@ -115,7 +116,7 @@ def get_current_user_org_project( project_id = None if api_key: - api_key_record = get_api_key_by_value(session, api_key) + api_key_record = verify_api_key(session, api_key) if api_key_record: validate_organization(session, api_key_record.organization_id) organization_id = api_key_record.organization_id @@ -150,6 +151,60 @@ def get_current_active_superuser_org(current_user: CurrentUserOrg) -> User: return current_user +def get_user_context( + session: SessionDep, + token: TokenDep, + api_key: Annotated[str, Depends(api_key_header)], +) -> UserContext: + """ + Verify valid authentication (API Key or JWT token) and return authenticated user context. + Returns UserContext with user info, project_id, and organization_id. + Authorization logic should be handled in routes. + """ + if api_key: + api_key_record = verify_api_key(session, api_key) + if not api_key_record: + raise HTTPException(status_code=401, detail="Invalid API Key") + + user = session.get(User, api_key_record.user_id) + if not user: + raise HTTPException( + status_code=404, detail="User linked to API Key not found" + ) + + user_context = UserContext( + **user.model_dump(), + project_id=api_key_record.project_id, + organization_id=api_key_record.organization_id, + ) + return user_context + + elif token: + try: + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + ) + token_data = TokenPayload(**payload) + except (InvalidTokenError, ValidationError): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Could not validate credentials", + ) + + user = session.get(User, token_data.sub) + if not user: + raise HTTPException(status_code=404, detail="User not found") + if not user.is_active: + raise HTTPException(status_code=400, detail="Inactive user") + + user_context = UserContext(**user.model_dump()) + + return user_context + + else: + raise HTTPException(status_code=401, detail="Invalid Authorization format") + + def verify_user_project_organization( db: SessionDep, current_user: CurrentUserOrg, diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index 125df075..cb98b7c7 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -1,116 +1,53 @@ -import logging -from fastapi import APIRouter, Depends, HTTPException +from typing import Annotated +from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from app.api.deps import get_db, get_current_active_superuser -from app.crud.api_key import ( - create_api_key, - get_api_key, - delete_api_key, - get_api_keys_by_project, - get_api_key_by_project_user, -) -from app.crud.project import validate_project -from app.models import APIKeyPublic, User + +from app.api.deps import get_db, get_current_active_superuser, get_user_context +from app.crud.api_key import APIKeyCrud +from app.models import APIKeyPublic, APIKeyCreateResponse, User, UserContext from app.utils import APIResponse -from app.core.exception_handlers import HTTPException -logger = logging.getLogger(__name__) -router = APIRouter(prefix="/apikeys", tags=["API Keys"]) +router = APIRouter(prefix="/api-keys", tags=["API Keys"]) -@router.post("/", response_model=APIResponse[APIKeyPublic]) -def create_key( +@router.post("/", response_model=APIResponse[APIKeyCreateResponse], status_code=201) +def create_api_key_route( project_id: int, - user_id: int, session: Session = Depends(get_db), current_user: User = Depends(get_current_active_superuser), ): """ - Generate a new API key for the user's organization. - """ - project = validate_project(session, project_id) + Create a new API key for the current project. - existing_api_key = get_api_key_by_project_user(session, project_id, user_id) - if existing_api_key: - logger.warning( - f"[create_key] API key already exists | project_id={project_id}, user_id={user_id}" - ) - raise HTTPException( - status_code=400, - detail="API Key already exists for this user and project.", - ) + The raw API key is returned only once during creation. + Store it securely as it cannot be retrieved again. + """ + api_key_crud = APIKeyCrud(session=session, project_id=project_id) + raw_key, api_key = api_key_crud.create( + user_id=current_user.id, + ) - api_key = create_api_key( - session, - organization_id=project.organization_id, - user_id=user_id, - project_id=project_id, + api_key = APIKeyCreateResponse( + **api_key.model_dump(), key=raw_key ) + return APIResponse.success_response(api_key) @router.get("/", response_model=APIResponse[list[APIKeyPublic]]) -def list_keys( - project_id: int, - session: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser), -): - """ - Retrieve all API keys for the given project. Superusers get all keys; - regular users get only their own. - """ - project = validate_project(session=session, project_id=project_id) - - if current_user.is_superuser: - api_keys = get_api_keys_by_project(session=session, project_id=project_id) - else: - user_api_key = get_api_key_by_project_user( - session=session, project_id=project_id, user_id=current_user.id - ) - api_keys = [user_api_key] if user_api_key else [] - - if not api_keys: - logger.warning(f"[list_keys] No API keys found | project_id={project_id}") - raise HTTPException( - status_code=404, - detail="No API keys found for this project.", - ) - - return APIResponse.success_response(api_keys) - - -@router.get("/{api_key_id}", response_model=APIResponse[APIKeyPublic]) -def get_key( - api_key_id: int, +def list_api_keys_route( session: Session = Depends(get_db), current_user: User = Depends(get_current_active_superuser), + skip: int = Query(0, ge=0, description="Number of records to skip"), + limit: int = Query(100, ge=1, le=100, description="Maximum records to return"), ): """ - Retrieve an API key by ID. - """ - api_key = get_api_key(session, api_key_id) - if not api_key: - logger.warning(f"[get_key] API key not found | api_key_id={api_key_id}") - raise HTTPException(404, "API Key does not exist") - - return APIResponse.success_response(api_key) + List all API keys for the current project. - -@router.delete("/{api_key_id}", response_model=APIResponse[dict]) -def revoke_key( - api_key_id: int, - session: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser), -): - """ - Soft delete an API key (revoke access). + Returns masked keys for security - the full key is only shown during creation. + Supports pagination via skip and limit parameters. """ - api_key = get_api_key(session, api_key_id) - if not api_key: - logger.warning( - f"[apikey.revoke] API key not found or already deleted | api_key_id={api_key_id}" - ) - raise HTTPException(404, "API key not found or already deleted") + crud = APIKeyCrud(session, current_user.project_id) + api_keys = crud.read_all(skip=skip, limit=limit) - delete_api_key(session, api_key_id) - return APIResponse.success_response({"message": "API key revoked successfully"}) + return APIResponse.success_response(api_keys) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 43ef1556..487495d9 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -27,16 +27,7 @@ validate_project, ) -from .api_key import ( - create_api_key, - generate_api_key, - get_api_key, - get_api_key_by_value, - get_api_keys_by_project, - get_api_key_by_project_user, - delete_api_key, - get_api_key_by_user_id, -) +from .api_key import APIKeyCrud, generate_api_key from .credentials import ( set_creds_for_org, diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index 7a8b7c16..7df8640d 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -1,182 +1,152 @@ -import uuid -import secrets import logging -from sqlmodel import Session, select -from app.core.security import ( - get_password_hash, - encrypt_api_key, - decrypt_api_key, -) -from app.core import settings -from app.core.util import now -from app.core.exception_handlers import HTTPException -from app.models.api_key import APIKey, APIKeyPublic - -logger = logging.getLogger(__name__) - +import secrets +from typing import Optional, Tuple -def generate_api_key() -> tuple[str, str]: - """Generate a new API key and its hash.""" - raw_key = "ApiKey " + secrets.token_urlsafe(32) - hashed_key = get_password_hash(raw_key) - return raw_key, hashed_key +from sqlmodel import Session, select, and_ +from fastapi import HTTPException +from passlib.context import CryptContext +from app.models import APIKey +from app.crud import get_project_by_id -def create_api_key( - session: Session, organization_id: int, user_id: int, project_id: int -) -> APIKeyPublic: - """ - Generates a new API key for an organization and associates it with a user. - Returns the API key details with the raw key (shown only once). - """ - # Generate raw key and its hash using the helper function - raw_key, hashed_key = generate_api_key() - encrypted_key = encrypt_api_key( - raw_key - ) # Encrypt the raw key instead of hashed key - - # Create API key record with encrypted raw key - api_key = APIKey( - key=encrypted_key, # Store the encrypted raw key - organization_id=organization_id, - user_id=user_id, - project_id=project_id, - ) - - session.add(api_key) - session.commit() - session.refresh(api_key) - - # Set the raw key in the response (shown only once) - api_key_dict = api_key.model_dump() - api_key_dict["key"] = raw_key # Return the raw key to the user - - logger.info( - f"[create_api_key] API key creation completed | {{'api_key_id': {api_key.id}, 'user_id': {user_id}, 'project_id': {project_id}}}" - ) - return APIKeyPublic.model_validate(api_key_dict) - - -def get_api_key(session: Session, api_key_id: int) -> APIKeyPublic | None: - """ - Retrieves an API key by its ID if it exists and is not deleted. - Returns the API key in its original format. - """ - api_key = session.exec( - select(APIKey).where(APIKey.id == api_key_id, APIKey.is_deleted == False) - ).first() +logger = logging.getLogger(__name__) - if api_key: - # Create a copy of the API key data - api_key_dict = api_key.model_dump() - # Decrypt the key - decrypted_key = decrypt_api_key(api_key.key) - api_key_dict["key"] = decrypted_key - return APIKeyPublic.model_validate(api_key_dict) - logger.warning(f"[get_api_key] API key not found | {{'api_key_id': {api_key_id}}}") - return None +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -def delete_api_key(session: Session, api_key_id: int) -> None: +def verify_api_key(session: Session, raw_key: str) -> Optional[APIKey]: """ - Soft deletes (revokes) an API key by marking it as deleted. + Verify an API key by extracting the prefix and checking the hash. + Returns the APIKey record if valid, None otherwise. """ - api_key = session.get(APIKey, api_key_id) - - if not api_key: - logger.warning( - f"[delete_api_key] API key not found | {{'api_key_id': {api_key_id}}}" + try: + # Check format: "ApiKey {key_prefix}{random_key}" + if not raw_key.startswith("ApiKey "): + return None + + # Extract the key part after "ApiKey " + key_part = raw_key[7:] # Remove "ApiKey " prefix + + # Extract key_prefix (first 22 chars - urlsafe base64 of 16 bytes) + if len(key_part) < 22: + return None + + key_prefix = key_part[:22] + + # Find API key by prefix + statement = select(APIKey).where( + and_( + APIKey.key_prefix == key_prefix, + APIKey.is_deleted.is_(False), + ) ) - return - - api_key.is_deleted = True - api_key.deleted_at = now() - api_key.updated_at = now() - - session.add(api_key) - session.commit() - logger.info( - f"[delete_api_key] API key soft deleted successfully | {{'api_key_id': {api_key_id}}}" - ) + api_key_record = session.exec(statement).one_or_none() + if not api_key_record: + return None -def get_api_key_by_value(session: Session, api_key_value: str) -> APIKeyPublic | None: - """ - Retrieve an API Key record by verifying the provided key against stored hashes. - Returns the API key in its original format. - """ - # Get all active API keys - api_keys = session.exec(select(APIKey).where(APIKey.is_deleted == False)).all() + # Verify hash + if pwd_context.verify(raw_key, api_key_record.key_hash): + return api_key_record - for api_key in api_keys: - decrypted_key = decrypt_api_key(api_key.key) - if api_key_value == decrypted_key: - api_key_dict = api_key.model_dump() - api_key_dict["key"] = decrypted_key - return APIKeyPublic.model_validate(api_key_dict) + return None - logger.warning( - f"[get_api_key_by_value] API key not found | {{'action': 'not_found'}}" - ) - return None + except Exception as e: + logger.error(f"[verify_api_key] Error verifying API key: {str(e)}", exc_info=True) + return None -def get_api_key_by_project_user( - session: Session, project_id: int, user_id: uuid.UUID -) -> APIKeyPublic | None: +def generate_api_key() -> Tuple[str, str, str]: """ - Retrieves the single API key associated with a project. + Generate a new API key with key_prefix and hash. """ - statement = select(APIKey).where( - APIKey.user_id == user_id, - APIKey.project_id == project_id, - APIKey.is_deleted == False, - ) - api_key = session.exec(statement).first() + random_key = secrets.token_urlsafe(32) + key_prefix = secrets.token_urlsafe(16) - if api_key: - api_key_dict = api_key.model_dump() - api_key_dict["key"] = decrypt_api_key(api_key.key) - return APIKeyPublic.model_validate(api_key_dict) + raw_key = f"ApiKey {key_prefix}{random_key}" - logger.warning( - f"[get_api_key_by_project_user] API key not found | {{'project_id': {project_id}, 'user_id': '{user_id}'}}" - ) - return None + key_hash = pwd_context.hash(raw_key) + return raw_key, key_prefix, key_hash -def get_api_keys_by_project(session: Session, project_id: int) -> list[APIKeyPublic]: - """ - Retrieves all API keys associated with a project. - """ - statement = select(APIKey).where( - APIKey.project_id == project_id, APIKey.is_deleted == False - ) - api_keys = session.exec(statement).all() - - result = [] - for key in api_keys: - key_dict = key.model_dump() - key_dict["key"] = decrypt_api_key(key.key) - result.append(APIKeyPublic.model_validate(key_dict)) - return result - - -def get_api_key_by_user_id(session: Session, user_id: int) -> APIKeyPublic | None: +class APIKeyCrud: """ - Retrieves the API key associated with a user by their user_id. + CRUD operations for API keys scoped to a project. """ - api_key = ( - session.query(APIKey) - .filter(APIKey.user_id == user_id, APIKey.is_deleted == False) - .first() - ) - - if not api_key: - return None - key_dict = api_key.model_dump() - key_dict["key"] = decrypt_api_key(api_key.key) - return APIKeyPublic.model_validate(key_dict) + def __init__(self, session: Session, project_id: int): + self.session = session + self.project_id = project_id + + def read_one(self, key_prefix: str) -> Optional[APIKey]: + """ + Retrieve a single non-deleted API key by its key_prefix. + """ + statement = select(APIKey).where( + and_( + APIKey.key_prefix == key_prefix, + APIKey.project_id == self.project_id, + APIKey.is_deleted.is_(False), + ) + ) + return self.session.exec(statement).one_or_none() + + def read_all(self, skip: int = 0, limit: int = 100) -> list[APIKey]: + """ + Read all non-deleted API keys for the project. + """ + statement = ( + select(APIKey) + .where( + and_( + APIKey.project_id == self.project_id, + APIKey.is_deleted.is_(False), + ) + ) + .offset(skip) + .limit(limit) + ) + return self.session.exec(statement).all() + + def create( + self, user_id: int + ) -> Tuple[str, APIKey]: + """ + Create a new API key for the project. + """ + try: + raw_key, key_prefix, key_hash = generate_api_key() + + project = get_project_by_id(session=self.session, project_id=self.project_id) + + api_key = APIKey( + key_prefix=key_prefix, + key_hash=key_hash, + user_id=user_id, + organization_id=project.organization_id, + project_id=self.project_id, + ) + + self.session.add(api_key) + self.session.commit() + self.session.refresh(api_key) + + logger.info( + f"[APIKeyCrud.create_api_key] API key created successfully | " + f"{{'api_key_id': '{api_key.id}', 'project_id': {self.project_id}, 'user_id': {user_id}}}" + ) + + return raw_key, api_key + + except Exception as e: + logger.error( + f"[APIKeyCrud.create_api_key] Failed to create API key | " + f"{{'project_id': {self.project_id}, 'user_id': {user_id}, 'error': '{str(e)}'}}", + exc_info=True, + ) + self.session.rollback() + raise HTTPException( + status_code=500, detail=f"Failed to create API key: {str(e)}" + ) diff --git a/backend/app/crud/onboarding.py b/backend/app/crud/onboarding.py index 8788b083..e77894ba 100644 --- a/backend/app/crud/onboarding.py +++ b/backend/app/crud/onboarding.py @@ -90,15 +90,16 @@ def onboard_project( session.add(user) session.flush() - raw_key, _ = generate_api_key() - encrypted_key = encrypt_api_key(raw_key) + raw_key, key_prefix, key_hash = generate_api_key() api_key = APIKey( - key=encrypted_key, # Store the encrypted raw key - organization_id=organization.id, + key_prefix=key_prefix, + key_hash=key_hash, user_id=user.id, + organization_id=project.organization_id, project_id=project.id, ) + session.add(api_key) credential = None diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 94c45ba3..a885bfb5 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,6 +1,9 @@ from sqlmodel import SQLModel from .auth import Token, TokenPayload + +from .api_key import APIKey, APIKeyBase, APIKeyPublic, APIKeyCreateResponse + from .collection import Collection from .document import ( Document, @@ -33,8 +36,6 @@ ProjectUpdate, ) -from .api_key import APIKey, APIKeyBase, APIKeyPublic - from .organization import ( Organization, OrganizationCreate, @@ -47,6 +48,7 @@ NewPassword, User, UserCreate, + UserContext, UserOrganization, UserProjectOrg, UserPublic, diff --git a/backend/app/models/api_key.py b/backend/app/models/api_key.py index 22387e1b..d6c9befa 100644 --- a/backend/app/models/api_key.py +++ b/backend/app/models/api_key.py @@ -1,5 +1,6 @@ -import uuid +from uuid import UUID, uuid4 import secrets +import base64 from datetime import datetime from typing import Optional, List from sqlmodel import SQLModel, Field, Relationship @@ -15,24 +16,29 @@ class APIKeyBase(SQLModel): foreign_key="project.id", nullable=False, ondelete="CASCADE" ) user_id: int = Field(foreign_key="user.id", nullable=False, ondelete="CASCADE") - key: str = Field( - default_factory=lambda: secrets.token_urlsafe(32), unique=True, index=True - ) class APIKeyPublic(APIKeyBase): - id: int - inserted_at: datetime = Field(default_factory=now, nullable=False) + id: UUID + key_prefix: str # Expose key_id for display (partial key identifier) + last_used_at: datetime | None + inserted_at: datetime + updated_at: datetime + + +class APIKeyCreateResponse(APIKeyPublic): + """Response model when creating an API key includes the raw key only once""" + key: str class APIKey(APIKeyBase, table=True): - id: int = Field(default=None, primary_key=True) + id: UUID = Field(default_factory=uuid4, primary_key=True) + + key_prefix: str = Field(unique=True, index=True, nullable=False) # Unique identifier from the key + key_hash: str = Field(nullable=False) # bcrypt hash of the full key + + last_used_at: datetime | None = Field(default=None, nullable=True) inserted_at: datetime = Field(default_factory=now, nullable=False) updated_at: datetime = Field(default_factory=now, nullable=False) is_deleted: bool = Field(default=False, nullable=False) deleted_at: Optional[datetime] = Field(default=None, nullable=True) - - # Relationships - organization: "Organization" = Relationship(back_populates="api_keys") - project: "Project" = Relationship(back_populates="api_keys") - user: "User" = Relationship(back_populates="api_keys") diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index 39342fda..4c729cdd 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -37,9 +37,6 @@ class Organization(OrganizationBase, table=True): updated_at: datetime = Field(default_factory=now, nullable=False) # Relationship back to Creds - api_keys: list["APIKey"] = Relationship( - back_populates="organization", cascade_delete=True - ) creds: list["Credential"] = Relationship( back_populates="organization", cascade_delete=True ) diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 8f0acfa6..b2d485f9 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -48,9 +48,6 @@ class Project(ProjectBase, table=True): assistants: list["Assistant"] = Relationship( back_populates="project", cascade_delete=True ) - api_keys: list["APIKey"] = Relationship( - back_populates="project", cascade_delete=True - ) organization: Optional["Organization"] = Relationship(back_populates="project") collections: list["Collection"] = Relationship( back_populates="project", cascade_delete=True diff --git a/backend/app/models/user.py b/backend/app/models/user.py index fa526ab5..13818307 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -54,8 +54,6 @@ class User(UserBase, table=True): projects: list["ProjectUser"] = Relationship( back_populates="user", cascade_delete=True ) - api_keys: list["APIKey"] = Relationship(back_populates="user", cascade_delete=True) - class UserOrganization(UserBase): id: int @@ -66,6 +64,12 @@ class UserProjectOrg(UserOrganization): project_id: int +class UserContext(UserBase): + id: int + project_id: int | None = None + organization_id: int | None = None + + # Properties to return via API, id is always required class UserPublic(UserBase): id: int diff --git a/backend/app/seed_data/seed_data.py b/backend/app/seed_data/seed_data.py index 8139f9c8..d06e9427 100644 --- a/backend/app/seed_data/seed_data.py +++ b/backend/app/seed_data/seed_data.py @@ -9,7 +9,7 @@ from app.core.db import engine from app.core import settings -from app.core.security import encrypt_api_key, get_password_hash +from app.core.security import get_password_hash, encrypt_credentials from app.models import ( APIKey, Organization, @@ -49,7 +49,6 @@ class APIKeyData(BaseModel): api_key: str is_deleted: bool deleted_at: Optional[str] = None - created_at: Optional[str] = None class CredentialData(BaseModel): @@ -184,21 +183,34 @@ def create_api_key(session: Session, api_key_data_raw: dict) -> APIKey: ).first() if not user: raise ValueError(f"User '{api_key_data.user_email}' not found") - encrypted_api_key = encrypt_api_key(api_key_data.api_key) + + # Extract key_prefix from the provided API key and hash the full key + # API key format: "ApiKey {key_prefix}{random_key}" where key_prefix is 16 chars + raw_key = api_key_data.api_key + if not raw_key.startswith("ApiKey "): + raise ValueError(f"Invalid API key format: {raw_key}") + + # Extract the key_prefix (first 16 characters after "ApiKey ") + key_portion = raw_key[7:] # Remove "ApiKey " prefix + + key_prefix = key_portion[:8] # First 8 characters as prefix + + # Hash the full raw key + from passlib.context import CryptContext + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + key_hash = pwd_context.hash(raw_key) + api_key = APIKey( organization_id=organization.id, project_id=project.id, user_id=user.id, - key=encrypted_api_key, + key_prefix=key_prefix, + key_hash=key_hash, is_deleted=api_key_data.is_deleted, deleted_at=api_key_data.deleted_at, ) - if api_key_data.created_at: - api_key.created_at = datetime.fromisoformat( - api_key_data.created_at.replace("Z", "+00:00") - ) session.add(api_key) - session.flush() # Ensure ID is assigned + session.flush() return api_key except Exception as e: logging.error(f"Error creating API key: {e}") @@ -229,8 +241,9 @@ def create_credential(session: Session, credential_data_raw: dict) -> Credential if not project: raise ValueError(f"Project '{credential_data.project_name}' not found") - # Encrypt the credential data - encrypted_credential = encrypt_api_key(credential_data.credential) + # Encrypt the credential data - convert string to dict first, then encrypt + credential_dict = json.loads(credential_data.credential) + encrypted_credential = encrypt_credentials(credential_dict) credential = Credential( is_active=credential_data.is_active, From fc36a288e3b78278618f762fe702339842e83be2 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Fri, 3 Oct 2025 12:50:14 +0530 Subject: [PATCH 02/37] pre commit --- .../d209cddac1fa_refactor_api_key_table.py | 36 +++++++++++-------- backend/app/api/routes/api_keys.py | 4 +-- backend/app/crud/api_key.py | 12 ++++--- backend/app/models/api_key.py | 5 ++- backend/app/models/user.py | 1 + backend/app/seed_data/seed_data.py | 7 ++-- 6 files changed, 38 insertions(+), 27 deletions(-) diff --git a/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py b/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py index 93c2db46..4e5428fa 100644 --- a/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py +++ b/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py @@ -11,31 +11,37 @@ # revision identifiers, used by Alembic. -revision = 'd209cddac1fa' -down_revision = 'c6fb6d0b5897' +revision = "d209cddac1fa" +down_revision = "c6fb6d0b5897" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column('apikey', sa.Column('key_prefix', sqlmodel.sql.sqltypes.AutoString(), nullable=False)) - op.add_column('apikey', sa.Column('key_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=False)) - op.add_column('apikey', sa.Column('last_used_at', sa.DateTime(), nullable=True)) - - op.add_column('apikey', sa.Column('new_id', sa.Uuid(), nullable=True)) + op.add_column( + "apikey", + sa.Column("key_prefix", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + ) + op.add_column( + "apikey", + sa.Column("key_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + ) + op.add_column("apikey", sa.Column("last_used_at", sa.DateTime(), nullable=True)) + + op.add_column("apikey", sa.Column("new_id", sa.Uuid(), nullable=True)) op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;") op.execute("UPDATE apikey SET new_id = gen_random_uuid();") # Replace old PK - op.drop_constraint('apikey_pkey', 'apikey', type_='primary') - op.drop_column('apikey', 'id') - op.alter_column('apikey', 'new_id', new_column_name='id', nullable=False) - op.create_primary_key('apikey_pkey', 'apikey', ['id']) - - op.drop_index('ix_apikey_key', table_name='apikey') - op.create_index(op.f('ix_apikey_key_prefix'), 'apikey', ['key_prefix'], unique=True) - op.drop_column('apikey', 'key') + op.drop_constraint("apikey_pkey", "apikey", type_="primary") + op.drop_column("apikey", "id") + op.alter_column("apikey", "new_id", new_column_name="id", nullable=False) + op.create_primary_key("apikey_pkey", "apikey", ["id"]) + + op.drop_index("ix_apikey_key", table_name="apikey") + op.create_index(op.f("ix_apikey_key_prefix"), "apikey", ["key_prefix"], unique=True) + op.drop_column("apikey", "key") # ### end Alembic commands ### diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index cb98b7c7..6a06b600 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -27,9 +27,7 @@ def create_api_key_route( user_id=current_user.id, ) - api_key = APIKeyCreateResponse( - **api_key.model_dump(), key=raw_key - ) + api_key = APIKeyCreateResponse(**api_key.model_dump(), key=raw_key) return APIResponse.success_response(api_key) diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index 7df8640d..161b7234 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -53,7 +53,9 @@ def verify_api_key(session: Session, raw_key: str) -> Optional[APIKey]: return None except Exception as e: - logger.error(f"[verify_api_key] Error verifying API key: {str(e)}", exc_info=True) + logger.error( + f"[verify_api_key] Error verifying API key: {str(e)}", exc_info=True + ) return None @@ -110,16 +112,16 @@ def read_all(self, skip: int = 0, limit: int = 100) -> list[APIKey]: ) return self.session.exec(statement).all() - def create( - self, user_id: int - ) -> Tuple[str, APIKey]: + def create(self, user_id: int) -> Tuple[str, APIKey]: """ Create a new API key for the project. """ try: raw_key, key_prefix, key_hash = generate_api_key() - project = get_project_by_id(session=self.session, project_id=self.project_id) + project = get_project_by_id( + session=self.session, project_id=self.project_id + ) api_key = APIKey( key_prefix=key_prefix, diff --git a/backend/app/models/api_key.py b/backend/app/models/api_key.py index d6c9befa..e3a0c136 100644 --- a/backend/app/models/api_key.py +++ b/backend/app/models/api_key.py @@ -28,13 +28,16 @@ class APIKeyPublic(APIKeyBase): class APIKeyCreateResponse(APIKeyPublic): """Response model when creating an API key includes the raw key only once""" + key: str class APIKey(APIKeyBase, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True) - key_prefix: str = Field(unique=True, index=True, nullable=False) # Unique identifier from the key + key_prefix: str = Field( + unique=True, index=True, nullable=False + ) # Unique identifier from the key key_hash: str = Field(nullable=False) # bcrypt hash of the full key last_used_at: datetime | None = Field(default=None, nullable=True) diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 13818307..f3523948 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -55,6 +55,7 @@ class User(UserBase, table=True): back_populates="user", cascade_delete=True ) + class UserOrganization(UserBase): id: int organization_id: int | None diff --git a/backend/app/seed_data/seed_data.py b/backend/app/seed_data/seed_data.py index d06e9427..3ba195d3 100644 --- a/backend/app/seed_data/seed_data.py +++ b/backend/app/seed_data/seed_data.py @@ -189,17 +189,18 @@ def create_api_key(session: Session, api_key_data_raw: dict) -> APIKey: raw_key = api_key_data.api_key if not raw_key.startswith("ApiKey "): raise ValueError(f"Invalid API key format: {raw_key}") - + # Extract the key_prefix (first 16 characters after "ApiKey ") key_portion = raw_key[7:] # Remove "ApiKey " prefix key_prefix = key_portion[:8] # First 8 characters as prefix - + # Hash the full raw key from passlib.context import CryptContext + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") key_hash = pwd_context.hash(raw_key) - + api_key = APIKey( organization_id=organization.id, project_id=project.id, From 600d84a142ae6ca574970788fa696b06d3c0fa96 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Fri, 3 Oct 2025 14:33:19 +0530 Subject: [PATCH 03/37] Add permission management for API key routes and user context dependency --- backend/app/api/deps.py | 3 ++ backend/app/api/permissions.py | 70 ++++++++++++++++++++++++++++++ backend/app/api/routes/api_keys.py | 26 +++++++---- 3 files changed, 91 insertions(+), 8 deletions(-) create mode 100644 backend/app/api/permissions.py diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 33c6bef4..dd705b09 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -205,6 +205,9 @@ def get_user_context( raise HTTPException(status_code=401, detail="Invalid Authorization format") +UserContextDep = Annotated[UserContext, Depends(get_user_context)] + + def verify_user_project_organization( db: SessionDep, current_user: CurrentUserOrg, diff --git a/backend/app/api/permissions.py b/backend/app/api/permissions.py new file mode 100644 index 00000000..21850cf2 --- /dev/null +++ b/backend/app/api/permissions.py @@ -0,0 +1,70 @@ +from enum import Enum +from typing import Annotated +from fastapi import Depends, HTTPException +from sqlmodel import Session + +from app.models import UserContext +from app.api.deps import UserContextDep, SessionDep + + +class Permission(str, Enum): + """Permission types for authorization checks""" + + SUPERUSER = "require_superuser" + REQUIRE_ORGANIZATION = "require_organization_id" + REQUIRE_PROJECT = "require_project_id" + + +def has_permission( + user_context: UserContext, + permission: Permission, + session: Session | None = None, +) -> bool: + """ + Check if the user_context has the specified permission. + + Args: + user_context: The authenticated user context + permission: The permission to check (Permission enum) + session: Database session (optional) + + Returns: + bool: True if user has permission, False otherwise + """ + match permission: + case Permission.SUPERUSER: + return user_context.is_superuser + case Permission.REQUIRE_ORGANIZATION: + return user_context.organization_id is not None + case Permission.REQUIRE_PROJECT: + return user_context.project_id is not None + case _: + return False + + +def require_permission(permission: Permission): + """ + Dependency factory for requiring specific permissions in FastAPI routes. + + Usage: + @app.get("/endpoint", dependencies=[Depends(require_permission(Permission.REQUIRE_ORGANIZATION))]) + def endpoint(user_context: Annotated[UserContext, Depends(get_user_context)]): + pass + """ + + def permission_checker( + user_context: UserContextDep, + session: SessionDep, + ): + if not has_permission(user_context, permission, session): + error_messages = { + Permission.SUPERUSER: "Insufficient permissions - require superuser access.", + Permission.REQUIRE_ORGANIZATION: "Insufficient permissions - require organization access.", + Permission.REQUIRE_PROJECT: "Insufficient permissions - require project access.", + } + raise HTTPException( + status_code=403, + detail=error_messages.get(permission, "Insufficient permissions"), + ) + + return permission_checker diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index 6a06b600..bc9cf79a 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -2,19 +2,25 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from app.api.deps import get_db, get_current_active_superuser, get_user_context +from app.api.deps import SessionDep, UserContextDep from app.crud.api_key import APIKeyCrud -from app.models import APIKeyPublic, APIKeyCreateResponse, User, UserContext +from app.models import APIKeyPublic, APIKeyCreateResponse from app.utils import APIResponse +from app.api.permissions import Permission, require_permission router = APIRouter(prefix="/api-keys", tags=["API Keys"]) -@router.post("/", response_model=APIResponse[APIKeyCreateResponse], status_code=201) +@router.post( + "/", + response_model=APIResponse[APIKeyCreateResponse], + status_code=201, + dependencies=[Depends(require_permission(Permission.SUPERUSER))], +) def create_api_key_route( project_id: int, - session: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser), + current_user: UserContextDep, + session: SessionDep, ): """ Create a new API key for the current project. @@ -32,10 +38,14 @@ def create_api_key_route( return APIResponse.success_response(api_key) -@router.get("/", response_model=APIResponse[list[APIKeyPublic]]) +@router.get( + "/", + response_model=APIResponse[list[APIKeyPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) def list_api_keys_route( - session: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser), + current_user: UserContextDep, + session: SessionDep, skip: int = Query(0, ge=0, description="Number of records to skip"), limit: int = Query(100, ge=1, le=100, description="Maximum records to return"), ): From 027565c2c5434df144dd9a3dd24c8fb3dc7edb04 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Mon, 6 Oct 2025 09:39:44 +0530 Subject: [PATCH 04/37] Refactor API key management: replace verify_api_key function with APIKeyManager class, add delete API key route, and update related CRUD operations --- .../d209cddac1fa_refactor_api_key_table.py | 1 - backend/app/api/deps.py | 10 +- backend/app/api/routes/api_keys.py | 24 ++- backend/app/crud/__init__.py | 2 +- backend/app/crud/api_key.py | 145 ++++++++++++------ backend/app/crud/onboarding.py | 4 +- backend/app/models/api_key.py | 3 +- 7 files changed, 127 insertions(+), 62 deletions(-) diff --git a/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py b/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py index 4e5428fa..2173394d 100644 --- a/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py +++ b/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py @@ -27,7 +27,6 @@ def upgrade(): "apikey", sa.Column("key_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=False), ) - op.add_column("apikey", sa.Column("last_used_at", sa.DateTime(), nullable=True)) op.add_column("apikey", sa.Column("new_id", sa.Uuid(), nullable=True)) op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;") diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index dd705b09..81b9ac9e 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -14,7 +14,7 @@ from app.core.db import engine from app.utils import APIResponse from app.crud.organization import validate_organization -from app.crud.api_key import verify_api_key +from app.crud.api_key import api_key_manager from app.models import ( TokenPayload, User, @@ -49,7 +49,7 @@ def get_current_user( """Authenticate user via API Key first, fallback to JWT token. Returns only User.""" if api_key: - api_key_record = verify_api_key(session, api_key) + api_key_record = api_key_manager.verify(session, api_key) if not api_key_record: raise HTTPException(status_code=401, detail="Invalid API Key") @@ -95,7 +95,7 @@ def get_current_user_org( organization_id = None api_key = request.headers.get("X-API-KEY") if api_key: - api_key_record = verify_api_key(session, api_key) + api_key_record = api_key_manager.verify(session, api_key) if api_key_record: validate_organization(session, api_key_record.organization_id) organization_id = api_key_record.organization_id @@ -116,7 +116,7 @@ def get_current_user_org_project( project_id = None if api_key: - api_key_record = verify_api_key(session, api_key) + api_key_record = api_key_manager.verify(session, api_key) if api_key_record: validate_organization(session, api_key_record.organization_id) organization_id = api_key_record.organization_id @@ -162,7 +162,7 @@ def get_user_context( Authorization logic should be handled in routes. """ if api_key: - api_key_record = verify_api_key(session, api_key) + api_key_record = api_key_manager.verify(session, api_key) if not api_key_record: raise HTTPException(status_code=401, detail="Invalid API Key") diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index bc9cf79a..b0198615 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -1,10 +1,9 @@ -from typing import Annotated +from uuid import UUID from fastapi import APIRouter, Depends, Query -from sqlmodel import Session from app.api.deps import SessionDep, UserContextDep from app.crud.api_key import APIKeyCrud -from app.models import APIKeyPublic, APIKeyCreateResponse +from app.models import APIKeyPublic, APIKeyCreateResponse, Message from app.utils import APIResponse from app.api.permissions import Permission, require_permission @@ -59,3 +58,22 @@ def list_api_keys_route( api_keys = crud.read_all(skip=skip, limit=limit) return APIResponse.success_response(api_keys) + + +@router.delete( + "/{key_id}", + response_model=APIResponse[Message], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def delete_api_key_route( + key_id: UUID, + current_user: UserContextDep, + session: SessionDep, +): + """ + Delete an API key by its ID. + """ + api_key_crud = APIKeyCrud(session=session, project_id=current_user.project_id) + api_key_crud.delete(key_id=key_id) + + return APIResponse.success_response(Message(message="API Key deleted successfully")) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 487495d9..38d55fc1 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -27,7 +27,7 @@ validate_project, ) -from .api_key import APIKeyCrud, generate_api_key +from .api_key import APIKeyCrud, api_key_manager from .credentials import ( set_creds_for_org, diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index 161b7234..ff657877 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -1,6 +1,7 @@ import logging import secrets -from typing import Optional, Tuple +from uuid import UUID +from typing import Tuple from sqlmodel import Session, select, and_ from fastapi import HTTPException @@ -8,69 +9,98 @@ from app.models import APIKey from app.crud import get_project_by_id +from app.core.util import now logger = logging.getLogger(__name__) -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - - -def verify_api_key(session: Session, raw_key: str) -> Optional[APIKey]: +class APIKeyManager: """ - Verify an API key by extracting the prefix and checking the hash. - Returns the APIKey record if valid, None otherwise. + Handles API key generation and verification using secure hashing. + + Key format: "ApiKey {22-char-prefix}{43-char-secret}" + - The prefix is stored plaintext for quick lookup + - Only the 43-char secret portion is hashed with bcrypt """ - try: - # Check format: "ApiKey {key_prefix}{random_key}" - if not raw_key.startswith("ApiKey "): - return None - # Extract the key part after "ApiKey " - key_part = raw_key[7:] # Remove "ApiKey " prefix + # Configuration constants + PREFIX_NAME = "ApiKey " + PREFIX_BYTES = 16 # Generates ~22 chars in urlsafe base64 + SECRET_BYTES = 32 # Generates ~43 chars in urlsafe base64 + PREFIX_LENGTH = 22 + KEY_LENGTH = 65 # Total length: 22 (prefix) + 43 (secret) + HASH_ALGORITHM = "bcrypt" - # Extract key_prefix (first 22 chars - urlsafe base64 of 16 bytes) - if len(key_part) < 22: - return None + pwd_context = CryptContext(schemes=[HASH_ALGORITHM], deprecated="auto") - key_prefix = key_part[:22] + @classmethod + def generate(cls) -> Tuple[str, str, str]: + """ + Generate a new API key with prefix and hashed value. - # Find API key by prefix - statement = select(APIKey).where( - and_( - APIKey.key_prefix == key_prefix, - APIKey.is_deleted.is_(False), - ) - ) - api_key_record = session.exec(statement).one_or_none() + Returns: + Tuple of (raw_key, key_prefix, key_hash) + """ + key_prefix = secrets.token_urlsafe(cls.PREFIX_BYTES) + secret_key = secrets.token_urlsafe(cls.SECRET_BYTES) - if not api_key_record: - return None + # Construct raw key: "ApiKey {prefix}{secret}" + raw_key = f"{cls.PREFIX_NAME}{key_prefix}{secret_key}" - # Verify hash - if pwd_context.verify(raw_key, api_key_record.key_hash): - return api_key_record + key_hash = cls.pwd_context.hash(secret_key) - return None + return raw_key, key_prefix, key_hash - except Exception as e: - logger.error( - f"[verify_api_key] Error verifying API key: {str(e)}", exc_info=True - ) - return None + @classmethod + def verify(cls, session: Session, raw_key: str) -> APIKey | None: + """ + Verify an API key by checking its prefix and hashed value. + Args: + session: Database session + raw_key: The raw API key to verify -def generate_api_key() -> Tuple[str, str, str]: - """ - Generate a new API key with key_prefix and hash. - """ - random_key = secrets.token_urlsafe(32) - key_prefix = secrets.token_urlsafe(16) + Returns: + The APIKey record if valid, None otherwise + """ + try: + expected_prefix = cls.PREFIX_NAME + if not raw_key.startswith(expected_prefix): + return None - raw_key = f"ApiKey {key_prefix}{random_key}" + key_part = raw_key[len(expected_prefix) :] + if len(key_part) != cls.KEY_LENGTH: + return None - key_hash = pwd_context.hash(raw_key) + key_prefix = key_part[: cls.PREFIX_LENGTH] + secret_key = key_part[cls.PREFIX_LENGTH :] - return raw_key, key_prefix, key_hash + statement = select(APIKey).where( + and_( + APIKey.key_prefix == key_prefix, + APIKey.is_deleted.is_(False), + ) + ) + api_key_record = session.exec(statement).one_or_none() + + if not api_key_record: + return None + + # Verify only the secret portion (43 chars) against the stored hash + if cls.pwd_context.verify(secret_key, api_key_record.key_hash): + return api_key_record + + return None + + except Exception as e: + logger.error( + f"[APIKeyManager.verify] Error verifying API key: {str(e)}", + exc_info=True, + ) + return None + + +api_key_manager = APIKeyManager() class APIKeyCrud: @@ -82,13 +112,13 @@ def __init__(self, session: Session, project_id: int): self.session = session self.project_id = project_id - def read_one(self, key_prefix: str) -> Optional[APIKey]: + def read_one(self, key_id: UUID) -> APIKey | None: """ Retrieve a single non-deleted API key by its key_prefix. """ statement = select(APIKey).where( and_( - APIKey.key_prefix == key_prefix, + APIKey.id == key_id, APIKey.project_id == self.project_id, APIKey.is_deleted.is_(False), ) @@ -117,7 +147,7 @@ def create(self, user_id: int) -> Tuple[str, APIKey]: Create a new API key for the project. """ try: - raw_key, key_prefix, key_hash = generate_api_key() + raw_key, key_prefix, key_hash = api_key_manager.generate() project = get_project_by_id( session=self.session, project_id=self.project_id @@ -152,3 +182,22 @@ def create(self, user_id: int) -> Tuple[str, APIKey]: raise HTTPException( status_code=500, detail=f"Failed to create API key: {str(e)}" ) + + def delete(self, key_id: UUID) -> None: + """ + Soft delete an API key by marking it as deleted. + """ + api_key = self.read_one(key_id) + if not api_key: + raise HTTPException(status_code=404, detail="API Key not found") + + api_key.is_deleted = True + api_key.deleted_at = now() + self.session.add(api_key) + self.session.commit() + self.session.refresh(api_key) + + logger.info( + f"[APIKeyCrud.delete_api_key] API key deleted successfully | " + f"{{'api_key_id': '{api_key.id}', 'project_id': {self.project_id}}}" + ) diff --git a/backend/app/crud/onboarding.py b/backend/app/crud/onboarding.py index e77894ba..26777d42 100644 --- a/backend/app/crud/onboarding.py +++ b/backend/app/crud/onboarding.py @@ -4,7 +4,7 @@ from app.core.security import encrypt_api_key, encrypt_credentials, get_password_hash from app.crud import ( - generate_api_key, + api_key_manager, get_organization_by_name, get_project_by_name, get_user_by_email, @@ -90,7 +90,7 @@ def onboard_project( session.add(user) session.flush() - raw_key, key_prefix, key_hash = generate_api_key() + raw_key, key_prefix, key_hash = api_key_manager.generate() api_key = APIKey( key_prefix=key_prefix, diff --git a/backend/app/models/api_key.py b/backend/app/models/api_key.py index e3a0c136..d90c9d9d 100644 --- a/backend/app/models/api_key.py +++ b/backend/app/models/api_key.py @@ -38,9 +38,8 @@ class APIKey(APIKeyBase, table=True): key_prefix: str = Field( unique=True, index=True, nullable=False ) # Unique identifier from the key - key_hash: str = Field(nullable=False) # bcrypt hash of the full key + key_hash: str = Field(nullable=False) # bcrypt hash of the secret portion - last_used_at: datetime | None = Field(default=None, nullable=True) inserted_at: datetime = Field(default_factory=now, nullable=False) updated_at: datetime = Field(default_factory=now, nullable=False) is_deleted: bool = Field(default=False, nullable=False) From b388cad28e4b3e16e579e4d315737644e61f8090 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Mon, 6 Oct 2025 11:31:55 +0530 Subject: [PATCH 05/37] Enhance API key management: add support for old key format in APIKeyManager, update key prefix extraction logic, and adjust seed data creation for API keys. --- backend/app/crud/api_key.py | 52 ++++++++++++++++++++++-------- backend/app/models/api_key.py | 1 - backend/app/seed_data/seed_data.py | 5 ++- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index ff657877..3e2f0971 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -17,10 +17,7 @@ class APIKeyManager: """ Handles API key generation and verification using secure hashing. - - Key format: "ApiKey {22-char-prefix}{43-char-secret}" - - The prefix is stored plaintext for quick lookup - - Only the 43-char secret portion is hashed with bcrypt + Supports Backwards compatibility with old key format. """ # Configuration constants @@ -51,10 +48,44 @@ def generate(cls) -> Tuple[str, str, str]: return raw_key, key_prefix, key_hash + @classmethod + def _extract_key_parts(cls, raw_key: str) -> Tuple[str, str] | None: + """ + Extract prefix and secret from an API key based on its format. + + Supports: + - New format: "ApiKey {22-char-prefix}{43-char-secret}" + - Old format: "ApiKey {12-char-prefix}{31-char-secret}" + + Returns: + Tuple[str, str] -> (key_prefix, secret_to_verify) + or None if invalid + """ + if not raw_key.startswith(cls.PREFIX_NAME): + return None + + key_part = raw_key[len(cls.PREFIX_NAME) :] + + if len(key_part) == cls.KEY_LENGTH: + key_prefix = key_part[: cls.PREFIX_LENGTH] + secret_key = key_part[cls.PREFIX_LENGTH :] + return key_prefix, secret_key + + old_key_length = 43 + old_prefix_length = 12 + if len(key_part) == old_key_length: + key_prefix = key_part[:old_prefix_length] + secret_key = key_part[old_prefix_length:] + return key_prefix, secret_key + + # Invalid format + return None + @classmethod def verify(cls, session: Session, raw_key: str) -> APIKey | None: """ Verify an API key by checking its prefix and hashed value. + Supports both old (43 chars) and new ("ApiKey " + 65 chars) formats. Args: session: Database session @@ -64,16 +95,12 @@ def verify(cls, session: Session, raw_key: str) -> APIKey | None: The APIKey record if valid, None otherwise """ try: - expected_prefix = cls.PREFIX_NAME - if not raw_key.startswith(expected_prefix): - return None + key_parts = cls._extract_key_parts(raw_key) - key_part = raw_key[len(expected_prefix) :] - if len(key_part) != cls.KEY_LENGTH: + if not key_parts: return None - key_prefix = key_part[: cls.PREFIX_LENGTH] - secret_key = key_part[cls.PREFIX_LENGTH :] + key_prefix, secret = key_parts statement = select(APIKey).where( and_( @@ -86,8 +113,7 @@ def verify(cls, session: Session, raw_key: str) -> APIKey | None: if not api_key_record: return None - # Verify only the secret portion (43 chars) against the stored hash - if cls.pwd_context.verify(secret_key, api_key_record.key_hash): + if cls.pwd_context.verify(secret, api_key_record.key_hash): return api_key_record return None diff --git a/backend/app/models/api_key.py b/backend/app/models/api_key.py index d90c9d9d..1da56382 100644 --- a/backend/app/models/api_key.py +++ b/backend/app/models/api_key.py @@ -21,7 +21,6 @@ class APIKeyBase(SQLModel): class APIKeyPublic(APIKeyBase): id: UUID key_prefix: str # Expose key_id for display (partial key identifier) - last_used_at: datetime | None inserted_at: datetime updated_at: datetime diff --git a/backend/app/seed_data/seed_data.py b/backend/app/seed_data/seed_data.py index 3ba195d3..6362aa1f 100644 --- a/backend/app/seed_data/seed_data.py +++ b/backend/app/seed_data/seed_data.py @@ -193,13 +193,12 @@ def create_api_key(session: Session, api_key_data_raw: dict) -> APIKey: # Extract the key_prefix (first 16 characters after "ApiKey ") key_portion = raw_key[7:] # Remove "ApiKey " prefix - key_prefix = key_portion[:8] # First 8 characters as prefix + key_prefix = key_portion[:12] # First 12 characters as prefix - # Hash the full raw key from passlib.context import CryptContext pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - key_hash = pwd_context.hash(raw_key) + key_hash = pwd_context.hash(key_portion[12:]) api_key = APIKey( organization_id=organization.id, From 87f9b70889465bc772126a535cad24c9ea7c2b7e Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Mon, 6 Oct 2025 11:50:30 +0530 Subject: [PATCH 06/37] Fix API key generation: ensure exact lengths for prefix and secret key, and update comments for clarity. --- backend/app/crud/api_key.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index 3e2f0971..6620d9d6 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -22,8 +22,8 @@ class APIKeyManager: # Configuration constants PREFIX_NAME = "ApiKey " - PREFIX_BYTES = 16 # Generates ~22 chars in urlsafe base64 - SECRET_BYTES = 32 # Generates ~43 chars in urlsafe base64 + PREFIX_BYTES = 16 # Generates 22 chars in urlsafe base64 + SECRET_BYTES = 32 # Generates 43 chars in urlsafe base64 PREFIX_LENGTH = 22 KEY_LENGTH = 65 # Total length: 22 (prefix) + 43 (secret) HASH_ALGORITHM = "bcrypt" @@ -34,12 +34,19 @@ class APIKeyManager: def generate(cls) -> Tuple[str, str, str]: """ Generate a new API key with prefix and hashed value. + Ensures exact lengths: prefix=22 chars, secret=43 chars. Returns: Tuple of (raw_key, key_prefix, key_hash) """ - key_prefix = secrets.token_urlsafe(cls.PREFIX_BYTES) - secret_key = secrets.token_urlsafe(cls.SECRET_BYTES) + # Generate tokens and ensure exact length + secret_length = cls.KEY_LENGTH - cls.PREFIX_LENGTH + key_prefix = secrets.token_urlsafe(cls.PREFIX_BYTES)[: cls.PREFIX_LENGTH].ljust( + cls.PREFIX_LENGTH, "A" + ) + secret_key = secrets.token_urlsafe(cls.SECRET_BYTES)[:secret_length].ljust( + secret_length, "A" + ) # Construct raw key: "ApiKey {prefix}{secret}" raw_key = f"{cls.PREFIX_NAME}{key_prefix}{secret_key}" From 17063f158abfb7be70d13e16f6238e912f4dc7bf Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Mon, 6 Oct 2025 13:42:22 +0530 Subject: [PATCH 07/37] Add migration script for API keys: convert encrypted keys to hashed format and update database schema --- backend/app/alembic/migrate_api_key.py | 193 ++++++++++++++++++ .../d209cddac1fa_refactor_api_key_table.py | 27 ++- 2 files changed, 215 insertions(+), 5 deletions(-) create mode 100644 backend/app/alembic/migrate_api_key.py diff --git a/backend/app/alembic/migrate_api_key.py b/backend/app/alembic/migrate_api_key.py new file mode 100644 index 00000000..e59f9b3a --- /dev/null +++ b/backend/app/alembic/migrate_api_key.py @@ -0,0 +1,193 @@ +""" +Migration script to convert encrypted API keys to hashed format. + +This script: +1. Decrypts existing API keys from the old encrypted format +2. Extracts the prefix and secret from the decrypted keys +3. Hashes the secret using bcrypt +4. Generates UUID4 for the new primary key +5. Stores the prefix, hash, and UUID in the new format for backward compatibility + +The format is: "ApiKey {12-char-prefix}{31-char-secret}" (total 43 chars) +""" + +import logging +import uuid +from sqlalchemy.orm import Session +from sqlalchemy import text +from passlib.context import CryptContext + +from app.core.security import decrypt_api_key + +logger = logging.getLogger(__name__) + +# Use the same hash algorithm as APIKeyManager +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# Old format constants +OLD_PREFIX_NAME = "ApiKey " +OLD_PREFIX_LENGTH = 12 +OLD_SECRET_LENGTH = 31 +OLD_KEY_LENGTH = 43 # Total: 12 + 31 + + +def migrate_api_keys(session: Session, generate_uuid: bool = False) -> None: + """ + Migrate all existing API keys from encrypted format to hashed format. + + This function: + 1. Fetches all API keys with the old 'key' column + 2. Decrypts each key + 3. Extracts prefix and secret + 4. Hashes the secret + 5. Generates UUID4 for new_id column if generate_uuid is True + 6. Updates key_prefix, key_hash, and optionally new_id columns + + Args: + session: SQLAlchemy database session + generate_uuid: Whether to generate and set UUID for new_id column + """ + logger.info("[migrate_api_keys] Starting API key migration from encrypted to hashed format") + + try: + # Fetch all API keys that have the old 'key' column + result = session.execute( + text("SELECT id, key FROM apikey WHERE key IS NOT NULL") + ) + api_keys = result.fetchall() + + if not api_keys: + logger.info("[migrate_api_keys] No API keys found to migrate") + return + + logger.info(f"[migrate_api_keys] Found {len(api_keys)} API keys to migrate") + + migrated_count = 0 + failed_count = 0 + + for row in api_keys: + key_id = row[0] + encrypted_key = row[1] + + try: + # Decrypt the API key + decrypted_key = decrypt_api_key(encrypted_key) + + # Validate format + if not decrypted_key.startswith(OLD_PREFIX_NAME): + logger.error( + f"[migrate_api_keys] Invalid key format for ID {key_id}: " + f"does not start with '{OLD_PREFIX_NAME}'" + ) + failed_count += 1 + continue + + # Extract the key part (after "ApiKey ") + key_part = decrypted_key[len(OLD_PREFIX_NAME):] + + if len(key_part) != OLD_KEY_LENGTH: + logger.error( + f"[migrate_api_keys] Invalid key length for ID {key_id}: " + f"expected {OLD_KEY_LENGTH}, got {len(key_part)}" + ) + failed_count += 1 + continue + + # Extract prefix and secret + key_prefix = key_part[:OLD_PREFIX_LENGTH] + secret_key = key_part[OLD_PREFIX_LENGTH:] + + # Hash the secret + key_hash = pwd_context.hash(secret_key) + + # Generate UUID if requested + if generate_uuid: + new_uuid = uuid.uuid4() + # Update the record with prefix, hash, and UUID + session.execute( + text( + "UPDATE apikey SET key_prefix = :prefix, key_hash = :hash, new_id = :new_id " + "WHERE id = :id" + ), + {"prefix": key_prefix, "hash": key_hash, "new_id": new_uuid, "id": key_id} + ) + else: + # Update the record with prefix and hash only + session.execute( + text( + "UPDATE apikey SET key_prefix = :prefix, key_hash = :hash " + "WHERE id = :id" + ), + {"prefix": key_prefix, "hash": key_hash, "id": key_id} + ) + + migrated_count += 1 + logger.info( + f"[migrate_api_keys] Successfully migrated key ID {key_id} " + f"with prefix {key_prefix[:4]}..." + ) + + except Exception as e: + logger.error( + f"[migrate_api_keys] Failed to migrate key ID {key_id}: {str(e)}", + exc_info=True + ) + failed_count += 1 + continue + + logger.info( + f"[migrate_api_keys] Migration completed: " + f"{migrated_count} successful, {failed_count} failed" + ) + + except Exception as e: + logger.error( + f"[migrate_api_keys] Fatal error during migration: {str(e)}", + exc_info=True + ) + raise + + +def verify_migration(session: Session) -> bool: + """ + Verify that all API keys have been migrated successfully. + + Args: + session: SQLAlchemy database session + + Returns: + bool: True if all keys are migrated, False otherwise + """ + try: + # Check for any keys with NULL key_prefix or key_hash + result = session.execute( + text( + "SELECT COUNT(*) FROM apikey " + "WHERE key_prefix IS NULL OR key_hash IS NULL" + ) + ) + null_count = result.scalar() + + if null_count > 0: + logger.warning( + f"[verify_migration] Found {null_count} API keys with NULL " + "key_prefix or key_hash" + ) + return False + + # Check total count + result = session.execute(text("SELECT COUNT(*) FROM apikey")) + total_count = result.scalar() + + logger.info( + f"[verify_migration] All {total_count} API keys have been " + "successfully migrated" + ) + return True + + except Exception as e: + logger.error( + f"[verify_migration] Error verifying migration: {str(e)}", + exc_info=True + ) + return False diff --git a/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py b/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py index 2173394d..20d66172 100644 --- a/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py +++ b/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py @@ -8,6 +8,9 @@ from alembic import op import sqlalchemy as sa import sqlmodel.sql.sqltypes +from sqlalchemy.orm import Session + +from app.alembic.migrate_api_key import migrate_api_keys, verify_migration # revision identifiers, used by Alembic. @@ -19,25 +22,39 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### + # Step 1: Add new columns as nullable to allow migration op.add_column( "apikey", - sa.Column("key_prefix", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("key_prefix", sqlmodel.sql.sqltypes.AutoString(), nullable=True), ) op.add_column( "apikey", - sa.Column("key_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("key_hash", sqlmodel.sql.sqltypes.AutoString(), nullable=True), ) + # Step 2: Add UUID column before migration op.add_column("apikey", sa.Column("new_id", sa.Uuid(), nullable=True)) - op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;") - op.execute("UPDATE apikey SET new_id = gen_random_uuid();") - # Replace old PK + # Step 3: Migrate existing encrypted keys to the new hashed format and generate UUIDs + bind = op.get_bind() + session = Session(bind=bind) + migrate_api_keys(session, generate_uuid=True) + + # Step 4: Verify migration was successful + if not verify_migration(session): + raise Exception("API key migration verification failed. Please check the logs.") + + # Step 5: Make the columns non-nullable after migration + op.alter_column("apikey", "key_prefix", nullable=False) + op.alter_column("apikey", "key_hash", nullable=False) + + # Step 6: Replace old PK with UUID-based PK op.drop_constraint("apikey_pkey", "apikey", type_="primary") op.drop_column("apikey", "id") op.alter_column("apikey", "new_id", new_column_name="id", nullable=False) op.create_primary_key("apikey_pkey", "apikey", ["id"]) + # Step 7: Update indexes and drop old key column op.drop_index("ix_apikey_key", table_name="apikey") op.create_index(op.f("ix_apikey_key_prefix"), "apikey", ["key_prefix"], unique=True) op.drop_column("apikey", "key") From 338be0ed933146e0fe80d4a2bba88cf17f350aae Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Mon, 6 Oct 2025 13:45:18 +0530 Subject: [PATCH 08/37] precommit --- backend/app/alembic/migrate_api_key.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/backend/app/alembic/migrate_api_key.py b/backend/app/alembic/migrate_api_key.py index e59f9b3a..ab48901d 100644 --- a/backend/app/alembic/migrate_api_key.py +++ b/backend/app/alembic/migrate_api_key.py @@ -47,7 +47,9 @@ def migrate_api_keys(session: Session, generate_uuid: bool = False) -> None: session: SQLAlchemy database session generate_uuid: Whether to generate and set UUID for new_id column """ - logger.info("[migrate_api_keys] Starting API key migration from encrypted to hashed format") + logger.info( + "[migrate_api_keys] Starting API key migration from encrypted to hashed format" + ) try: # Fetch all API keys that have the old 'key' column @@ -83,7 +85,7 @@ def migrate_api_keys(session: Session, generate_uuid: bool = False) -> None: continue # Extract the key part (after "ApiKey ") - key_part = decrypted_key[len(OLD_PREFIX_NAME):] + key_part = decrypted_key[len(OLD_PREFIX_NAME) :] if len(key_part) != OLD_KEY_LENGTH: logger.error( @@ -109,7 +111,12 @@ def migrate_api_keys(session: Session, generate_uuid: bool = False) -> None: "UPDATE apikey SET key_prefix = :prefix, key_hash = :hash, new_id = :new_id " "WHERE id = :id" ), - {"prefix": key_prefix, "hash": key_hash, "new_id": new_uuid, "id": key_id} + { + "prefix": key_prefix, + "hash": key_hash, + "new_id": new_uuid, + "id": key_id, + }, ) else: # Update the record with prefix and hash only @@ -118,7 +125,7 @@ def migrate_api_keys(session: Session, generate_uuid: bool = False) -> None: "UPDATE apikey SET key_prefix = :prefix, key_hash = :hash " "WHERE id = :id" ), - {"prefix": key_prefix, "hash": key_hash, "id": key_id} + {"prefix": key_prefix, "hash": key_hash, "id": key_id}, ) migrated_count += 1 @@ -130,7 +137,7 @@ def migrate_api_keys(session: Session, generate_uuid: bool = False) -> None: except Exception as e: logger.error( f"[migrate_api_keys] Failed to migrate key ID {key_id}: {str(e)}", - exc_info=True + exc_info=True, ) failed_count += 1 continue @@ -142,8 +149,7 @@ def migrate_api_keys(session: Session, generate_uuid: bool = False) -> None: except Exception as e: logger.error( - f"[migrate_api_keys] Fatal error during migration: {str(e)}", - exc_info=True + f"[migrate_api_keys] Fatal error during migration: {str(e)}", exc_info=True ) raise @@ -187,7 +193,6 @@ def verify_migration(session: Session) -> bool: except Exception as e: logger.error( - f"[verify_migration] Error verifying migration: {str(e)}", - exc_info=True + f"[verify_migration] Error verifying migration: {str(e)}", exc_info=True ) return False From 1e90f01f95d8554ce45fd18b588d553c099f7d70 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Mon, 6 Oct 2025 15:41:57 +0530 Subject: [PATCH 09/37] Refactor API key handling in tests to use AuthContext - Removed direct usage of APIKeyPublic in tests and replaced with AuthContext for better encapsulation of authentication details. - Updated test cases across various modules (e.g., test_api_key, test_assistants, test_creds, etc.) to utilize the new AuthContext structure. - Simplified the retrieval of API keys and user context in test setups, enhancing readability and maintainability. - Removed redundant API key creation logic from test utilities and centralized it within the AuthContext. - Ensured all tests still pass with the new structure, maintaining functionality while improving code organization. --- backend/app/seed_data/seed_data.json | 2 +- .../collections/test_collection_info.py | 33 ++--- .../collections/test_create_collections.py | 10 +- .../documents/test_route_document_upload.py | 4 +- backend/app/tests/api/routes/test_api_key.py | 117 --------------- .../app/tests/api/routes/test_assistants.py | 48 +++--- backend/app/tests/api/routes/test_creds.py | 34 ++--- .../api/routes/test_doc_transformation_job.py | 37 ++--- .../api/routes/test_openai_conversation.py | 30 ++-- backend/app/tests/conftest.py | 36 +++-- .../doctransformer/test_service/conftest.py | 8 +- backend/app/tests/crud/test_api_key.py | 117 --------------- backend/app/tests/crud/test_onboarding.py | 1 - backend/app/tests/utils/auth.py | 137 ++++++++++++++++++ backend/app/tests/utils/collection.py | 9 +- backend/app/tests/utils/document.py | 7 +- backend/app/tests/utils/test_data.py | 20 +-- backend/app/tests/utils/utils.py | 36 ++--- 18 files changed, 278 insertions(+), 408 deletions(-) create mode 100644 backend/app/tests/utils/auth.py diff --git a/backend/app/seed_data/seed_data.json b/backend/app/seed_data/seed_data.json index 3427375b..3deabfc3 100644 --- a/backend/app/seed_data/seed_data.json +++ b/backend/app/seed_data/seed_data.json @@ -48,7 +48,7 @@ "organization_name": "Project Tech4dev", "user_email": "{{ADMIN_EMAIL}}", "project_name": "Dalgo", - "api_key": "ApiKey Px8y47B6roJHin1lWLkR88eiDrFdXSJRZmFQazzai8j9", + "api_key": "ApiKey Px8y47B6roJHin1lWLkR88eiDrFdXSJRZmFQazzai8j", "is_deleted": false, "deleted_at": null } diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 5747f790..66ebb01e 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -4,22 +4,22 @@ from sqlmodel import Session from app.core.config import settings from app.models import Collection -from app.tests.utils.utils import get_user_from_api_key from app.models.collection import CollectionStatus +from app.tests.utils.auth import AuthContext def create_collection( db, - user, + user_api_key: AuthContext, status: CollectionStatus = CollectionStatus.processing, with_llm: bool = False, ): now = datetime.now(timezone.utc) collection = Collection( id=uuid4(), - owner_id=user.user_id, - organization_id=user.organization_id, - project_id=user.project_id, + owner_id=user_api_key.user_id, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, status=status, updated_at=now, ) @@ -34,11 +34,10 @@ def create_collection( def test_collection_info_processing( - db: Session, client: TestClient, user_api_key_header + db: Session, client: TestClient, user_api_key: AuthContext ): - headers = user_api_key_header - user = get_user_from_api_key(db, headers) - collection = create_collection(db, user, status=CollectionStatus.processing) + headers = {"X-API-KEY": user_api_key.key} + collection = create_collection(db, user_api_key, status=CollectionStatus.processing) response = client.post( f"{settings.API_V1_STR}/collections/info/{collection.id}", @@ -55,12 +54,11 @@ def test_collection_info_processing( def test_collection_info_successful( - db: Session, client: TestClient, user_api_key_header + db: Session, client: TestClient, user_api_key: AuthContext ): - headers = user_api_key_header - user = get_user_from_api_key(db, headers) + headers = {"X-API-KEY": user_api_key.key} collection = create_collection( - db, user, status=CollectionStatus.successful, with_llm=True + db, user_api_key, status=CollectionStatus.successful, with_llm=True ) response = client.post( @@ -77,10 +75,11 @@ def test_collection_info_successful( assert data["llm_service_name"] == "gpt-4o" -def test_collection_info_failed(db: Session, client: TestClient, user_api_key_header): - headers = user_api_key_header - user = get_user_from_api_key(db, headers) - collection = create_collection(db, user, status=CollectionStatus.failed) +def test_collection_info_failed( + db: Session, client: TestClient, user_api_key: AuthContext +): + headers = {"X-API-KEY": user_api_key.key} + collection = create_collection(db, user_api_key, status=CollectionStatus.failed) response = client.post( f"{settings.API_V1_STR}/collections/info/{collection.id}", diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index 22764df4..ff1e00af 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -6,13 +6,12 @@ from fastapi.testclient import TestClient from unittest.mock import patch -from app.models import APIKeyPublic from app.core.config import settings from app.tests.utils.document import DocumentStore -from app.tests.utils.utils import get_user_from_api_key from app.crud.collection import CollectionCrud from app.models.collection import CollectionStatus from app.tests.utils.openai import get_mock_openai_client_with_vector_store +from app.tests.utils.auth import AuthContext @pytest.fixture(autouse=True) @@ -49,7 +48,7 @@ def test_create_collection_success( mock_get_openai_client, client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): store = DocumentStore(db, project_id=user_api_key.project_id) documents = store.fill(self._n_documents) @@ -81,8 +80,7 @@ def test_create_collection_success( # Confirm collection metadata in DB collection_id = UUID(metadata["key"]) - user = get_user_from_api_key(db, headers) - collection = CollectionCrud(db, user.user_id).read_one(collection_id) + collection = CollectionCrud(db, user_api_key.user_id).read_one(collection_id) info_response = client.post( f"{settings.API_V1_STR}/collections/info/{collection_id}", @@ -92,6 +90,6 @@ def test_create_collection_success( info_data = info_response.json()["data"] assert collection.status == CollectionStatus.successful.value - assert collection.owner_id == user.user_id + assert collection.owner_id == user_api_key.user_id assert collection.llm_service_id is not None assert collection.llm_service_name == "gpt-4o" diff --git a/backend/app/tests/api/routes/documents/test_route_document_upload.py b/backend/app/tests/api/routes/documents/test_route_document_upload.py index 4b73a451..de6d6dd1 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_upload.py +++ b/backend/app/tests/api/routes/documents/test_route_document_upload.py @@ -10,7 +10,6 @@ from sqlmodel import Session, select from fastapi.testclient import TestClient -from app.models import APIKeyPublic from app.core.cloud import AmazonCloudStorageClient from app.core.config import settings from app.models import Document @@ -19,6 +18,7 @@ WebCrawler, httpx_to_standard, ) +from app.tests.utils.auth import AuthContext class WebUploader(WebCrawler): @@ -77,7 +77,7 @@ def route(): @pytest.fixture -def uploader(client: TestClient, user_api_key: APIKeyPublic): +def uploader(client: TestClient, user_api_key: AuthContext): return WebUploader(client, user_api_key) diff --git a/backend/app/tests/api/routes/test_api_key.py b/backend/app/tests/api/routes/test_api_key.py index c4e63954..e69de29b 100644 --- a/backend/app/tests/api/routes/test_api_key.py +++ b/backend/app/tests/api/routes/test_api_key.py @@ -1,117 +0,0 @@ -from fastapi.testclient import TestClient -from sqlmodel import Session - -from app.main import app -from app.models import APIKey -from app.core.config import settings -from app.tests.utils.utils import get_non_existent_id -from app.tests.utils.user import create_random_user -from app.tests.utils.test_data import create_test_api_key, create_test_project - -client = TestClient(app) - - -def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_random_user(db) - project = create_test_project(db) - - response = client.post( - f"{settings.API_V1_STR}/apikeys", - params={"project_id": project.id, "user_id": user.id}, - headers=superuser_token_headers, - ) - assert response.status_code == 200 - data = response.json() - assert data["success"] is True - assert "id" in data["data"] - assert "key" in data["data"] - assert data["data"]["organization_id"] == project.organization_id - assert data["data"]["user_id"] == user.id - - -def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str, str]): - user = create_random_user(db) - project = create_test_project(db) - - client.post( - f"{settings.API_V1_STR}/apikeys", - params={"project_id": project.id, "user_id": user.id}, - headers=superuser_token_headers, - ) - response = client.post( - f"{settings.API_V1_STR}/apikeys", - params={"project_id": project.id, "user_id": user.id}, - headers=superuser_token_headers, - ) - assert response.status_code == 400 - assert "API Key already exists" in response.json()["error"] - - -def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]): - api_key = create_test_api_key(db) - - response = client.get( - f"{settings.API_V1_STR}/apikeys", - params={"project_id": api_key.project_id}, - headers=superuser_token_headers, - ) - assert response.status_code == 200 - data = response.json() - assert data["success"] is True - assert isinstance(data["data"], list) - assert len(data["data"]) > 0 - - first_key = data["data"][0] - assert first_key["organization_id"] == api_key.organization_id - assert first_key["user_id"] == api_key.user_id - - -def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]): - api_key = create_test_api_key(db) - - response = client.get( - f"{settings.API_V1_STR}/apikeys/{api_key.id}", - headers=superuser_token_headers, - ) - assert response.status_code == 200 - data = response.json() - assert data["success"] is True - assert data["data"]["id"] == api_key.id - assert data["data"]["organization_id"] == api_key.organization_id - assert data["data"]["user_id"] == api_key.user_id - - -def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, str]): - api_key_id = get_non_existent_id(db, APIKey) - response = client.get( - f"{settings.API_V1_STR}/apikeys/{api_key_id}", - headers=superuser_token_headers, - ) - assert response.status_code == 404 - assert "API Key does not exist" in response.json()["error"] - - -def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]): - api_key = create_test_api_key(db) - - response = client.delete( - f"{settings.API_V1_STR}/apikeys/{api_key.id}", - headers=superuser_token_headers, - ) - assert response.status_code == 200 - data = response.json() - assert data["success"] is True - assert "API key revoked successfully" in data["data"]["message"] - - -def test_revoke_nonexistent_api_key( - db: Session, superuser_token_headers: dict[str, str] -): - api_key_id = get_non_existent_id(db, APIKey) - - response = client.delete( - f"{settings.API_V1_STR}/apikeys/{api_key_id}", - headers=superuser_token_headers, - ) - assert response.status_code == 404 - assert "API key not found or already deleted" in response.json()["error"] diff --git a/backend/app/tests/api/routes/test_assistants.py b/backend/app/tests/api/routes/test_assistants.py index 8d236696..e5f73fe9 100644 --- a/backend/app/tests/api/routes/test_assistants.py +++ b/backend/app/tests/api/routes/test_assistants.py @@ -4,9 +4,9 @@ from fastapi import HTTPException from fastapi.testclient import TestClient from unittest.mock import patch -from app.crud.api_key import get_api_keys_by_project from app.tests.utils.openai import mock_openai_assistant from app.tests.utils.utils import get_assistant +from app.tests.utils.auth import AuthContext @pytest.fixture @@ -30,7 +30,7 @@ def assistant_id(): def test_ingest_assistant_success( mock_fetch_assistant, client: TestClient, - user_api_key_header: dict[str, str], + user_api_key: AuthContext, ): """Test successful assistant ingestion from OpenAI.""" mock_assistant = mock_openai_assistant() @@ -39,7 +39,7 @@ def test_ingest_assistant_success( response = client.post( f"/api/v1/assistant/{mock_assistant.id}/ingest", - headers=user_api_key_header, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 201 @@ -53,7 +53,7 @@ def test_create_assistant_success( mock_verify_vector_ids, client: TestClient, assistant_create_payload: dict, - user_api_key_header: dict, + user_api_key: AuthContext, ): """Test successful assistant creation with OpenAI vector store ID verification.""" @@ -62,7 +62,7 @@ def test_create_assistant_success( response = client.post( "/api/v1/assistant", json=assistant_create_payload, - headers=user_api_key_header, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 201 @@ -92,7 +92,7 @@ def test_create_assistant_invalid_vector_store( mock_verify_vector_ids, client: TestClient, assistant_create_payload: dict, - user_api_key_header: dict, + user_api_key: AuthContext, ): """Test failure when one or more vector store IDs are invalid.""" @@ -106,7 +106,7 @@ def test_create_assistant_invalid_vector_store( response = client.post( "/api/v1/assistant", json=payload, - headers=user_api_key_header, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 400 @@ -117,6 +117,7 @@ def test_create_assistant_invalid_vector_store( def test_update_assistant_success( client: TestClient, db: Session, + user_api_key: AuthContext, ): """Test successful assistant update.""" update_payload = { @@ -127,13 +128,12 @@ def test_update_assistant_success( "max_num_results": 5, } - assistant = get_assistant(db) - api_key = get_api_keys_by_project(db, assistant.project_id)[0] + assistant = get_assistant(db, project_id=user_api_key.project_id) response = client.patch( f"/api/v1/assistant/{assistant.assistant_id}", json=update_payload, - headers={"X-API-KEY": f"{api_key.key}"}, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 200 @@ -151,6 +151,7 @@ def test_update_assistant_invalid_vector_store( mock_verify_vector_ids, client: TestClient, db: Session, + user_api_key: AuthContext, ): """Test failure when updating assistant with invalid vector store IDs.""" mock_verify_vector_ids.side_effect = HTTPException( @@ -159,13 +160,12 @@ def test_update_assistant_invalid_vector_store( update_payload = {"vector_store_ids_add": ["vs_invalid"]} - assistant = get_assistant(db) - api_key = get_api_keys_by_project(db, assistant.project_id)[0] + assistant = get_assistant(db, project_id=user_api_key.project_id) response = client.patch( f"/api/v1/assistant/{assistant.assistant_id}", json=update_payload, - headers={"X-API-KEY": f"{api_key.key}"}, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 400 @@ -175,7 +175,7 @@ def test_update_assistant_invalid_vector_store( def test_update_assistant_not_found( client: TestClient, - user_api_key_header: dict, + user_api_key: AuthContext, ): """Test failure when updating a non-existent assistant.""" update_payload = {"name": "Updated Assistant"} @@ -185,7 +185,7 @@ def test_update_assistant_not_found( response = client.patch( f"/api/v1/assistant/{non_existent_id}", json=update_payload, - headers=user_api_key_header, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 404 @@ -196,14 +196,14 @@ def test_update_assistant_not_found( def test_get_assistant_success( client: TestClient, db: Session, + user_api_key: AuthContext, ): """Test successful retrieval of a single assistant.""" - assistant = get_assistant(db) - api_key = get_api_keys_by_project(db, assistant.project_id)[0] + assistant = get_assistant(db, project_id=user_api_key.project_id) response = client.get( f"/api/v1/assistant/{assistant.assistant_id}", - headers={"X-API-KEY": f"{api_key.key}"}, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 200 @@ -235,14 +235,14 @@ def test_get_assistant_not_found( def test_list_assistants_success( client: TestClient, db: Session, + user_api_key: AuthContext, ): """Test successful retrieval of assistants list.""" - assistant = get_assistant(db) - api_key = get_api_keys_by_project(db, assistant.project_id)[0] + assistant = get_assistant(db, project_id=user_api_key.project_id) response = client.get( "/api/v1/assistant/", - headers={"X-API-KEY": f"{api_key.key}"}, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 200 @@ -286,14 +286,14 @@ def test_list_assistants_invalid_pagination( def test_delete_assistant_success( client: TestClient, db: Session, + user_api_key: AuthContext, ): """Test successful soft deletion of an assistant.""" - assistant = get_assistant(db) - api_key = get_api_keys_by_project(db, assistant.project_id)[0] + assistant = get_assistant(db, project_id=user_api_key.project_id) response = client.delete( f"/api/v1/assistant/{assistant.assistant_id}", - headers={"X-API-KEY": f"{api_key.key}"}, + headers={"X-API-KEY": f"{user_api_key.key}"}, ) assert response.status_code == 200 diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 8625cf9e..7832bc7c 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -2,7 +2,6 @@ from fastapi.testclient import TestClient from sqlmodel import Session -from app.models import APIKeyPublic from app.core.config import settings from app.core.providers import Provider from app.models.credentials import Credential @@ -14,6 +13,7 @@ create_test_credential, test_credential_data, ) +from app.tests.utils.auth import AuthContext @pytest.fixture @@ -23,7 +23,7 @@ def create_test_credentials(db: Session): def test_set_credential( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): project_id = user_api_key.project_id org_id = user_api_key.organization_id @@ -65,7 +65,7 @@ def test_set_credential( def test_set_credentials_ignored_mismatched_ids( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): # Even if mismatched IDs are sent, route uses API key context # Ensure clean state for provider @@ -90,7 +90,7 @@ def test_set_credentials_ignored_mismatched_ids( def test_read_credentials_with_creds( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): # Ensure at least one credential exists for current project api_key_value = "sk-" + generate_random_string(10) @@ -120,7 +120,7 @@ def test_read_credentials_with_creds( def test_read_credentials_not_found( - client: TestClient, db: Session, user_api_key: APIKeyPublic + client: TestClient, db: Session, user_api_key: AuthContext ): # Delete all first to ensure none remain client.delete( @@ -136,7 +136,7 @@ def test_read_credentials_not_found( def test_read_provider_credential( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): # Ensure exists client.delete( @@ -168,7 +168,7 @@ def test_read_provider_credential( def test_read_provider_credential_not_found( - client: TestClient, db: Session, user_api_key: APIKeyPublic + client: TestClient, db: Session, user_api_key: AuthContext ): # Ensure none client.delete( @@ -185,7 +185,7 @@ def test_read_provider_credential_not_found( def test_update_credentials( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): # Ensure exists client.delete( @@ -230,7 +230,7 @@ def test_update_credentials( def test_update_credentials_not_found_for_provider( - client: TestClient, db: Session, user_api_key: APIKeyPublic + client: TestClient, db: Session, user_api_key: AuthContext ): # Ensure none exist client.delete( @@ -257,7 +257,7 @@ def test_update_credentials_not_found_for_provider( def test_delete_provider_credential( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): # Ensure exists client.delete( @@ -284,7 +284,7 @@ def test_delete_provider_credential( def test_delete_provider_credential_not_found( - client: TestClient, db: Session, user_api_key: APIKeyPublic + client: TestClient, db: Session, user_api_key: AuthContext ): # Ensure not exists client.delete( @@ -301,7 +301,7 @@ def test_delete_provider_credential_not_found( def test_delete_all_credentials( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): # Ensure exists client.delete( @@ -339,7 +339,7 @@ def test_delete_all_credentials( def test_delete_all_credentials_not_found( - client: TestClient, db: Session, user_api_key: APIKeyPublic + client: TestClient, db: Session, user_api_key: AuthContext ): # Ensure already deleted client.delete( @@ -357,7 +357,7 @@ def test_delete_all_credentials_not_found( def test_duplicate_credential_creation( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): credential = test_credential_data(db) # Ensure clean state for provider @@ -386,7 +386,7 @@ def test_duplicate_credential_creation( def test_multiple_provider_credentials( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): # Ensure clean state for current org/project client.delete( @@ -453,7 +453,7 @@ def test_multiple_provider_credentials( def test_credential_encryption( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): credential = test_credential_data(db) original_api_key = credential.credential[Provider.OPENAI.value]["api_key"] @@ -491,7 +491,7 @@ def test_credential_encryption( def test_credential_encryption_consistency( - client: TestClient, db: Session, user_api_key: APIKeyPublic + client: TestClient, db: Session, user_api_key: AuthContext ): credentials = test_credential_data(db) original_api_key = credentials.credential[Provider.OPENAI.value]["api_key"] diff --git a/backend/app/tests/api/routes/test_doc_transformation_job.py b/backend/app/tests/api/routes/test_doc_transformation_job.py index a808c430..d2f4510a 100644 --- a/backend/app/tests/api/routes/test_doc_transformation_job.py +++ b/backend/app/tests/api/routes/test_doc_transformation_job.py @@ -3,13 +3,14 @@ from app.core.config import settings from app.crud.doc_transformation_job import DocTransformationJobCrud -from app.models import APIKeyPublic, TransformationStatus +from app.models import TransformationStatus from app.tests.utils.document import DocumentStore +from app.tests.utils.auth import AuthContext class TestGetTransformationJob: def test_get_existing_job_success( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: AuthContext ): """Test successfully retrieving an existing transformation job.""" document = DocumentStore(db, user_api_key.project_id).put() @@ -31,7 +32,7 @@ def test_get_existing_job_success( assert data["data"]["transformed_document_id"] is None def test_get_nonexistent_job_404( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: AuthContext ): """Test getting a non-existent transformation job returns 404.""" fake_uuid = "00000000-0000-0000-0000-000000000001" @@ -44,7 +45,7 @@ def test_get_nonexistent_job_404( assert response.status_code == 404 def test_get_job_invalid_uuid_422( - self, client: TestClient, user_api_key: APIKeyPublic + self, client: TestClient, user_api_key: AuthContext ): """Test getting a job with invalid UUID format returns 422.""" invalid_uuid = "not-a-uuid" @@ -60,8 +61,8 @@ def test_get_job_different_project_404( self, client: TestClient, db: Session, - user_api_key: APIKeyPublic, - superuser_api_key: APIKeyPublic, + user_api_key: AuthContext, + superuser_api_key: AuthContext, ): """Test that jobs from different projects are not accessible.""" store = DocumentStore(db, user_api_key.project_id) @@ -78,7 +79,7 @@ def test_get_job_different_project_404( assert response.status_code == 404 def test_get_completed_job_with_result( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: AuthContext ): """Test getting a completed job with transformation result.""" store = DocumentStore(db, user_api_key.project_id) @@ -105,7 +106,7 @@ def test_get_completed_job_with_result( assert data["data"]["transformed_document_id"] == str(transformed_document.id) def test_get_failed_job_with_error( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: AuthContext ): """Test getting a failed job with error message.""" store = DocumentStore(db, user_api_key.project_id) @@ -130,7 +131,7 @@ def test_get_failed_job_with_error( class TestGetMultipleTransformationJobs: def test_get_multiple_jobs_success( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: AuthContext ): """Test successfully retrieving multiple transformation jobs.""" store = DocumentStore(db, user_api_key.project_id) @@ -155,7 +156,7 @@ def test_get_multiple_jobs_success( assert returned_ids == expected_ids def test_get_mixed_existing_nonexisting_jobs( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: AuthContext ): """Test retrieving a mix of existing and non-existing jobs.""" store = DocumentStore(db, user_api_key.project_id) @@ -180,7 +181,7 @@ def test_get_mixed_existing_nonexisting_jobs( assert data["data"]["jobs_not_found"][0] == fake_uuid def test_get_jobs_with_empty_string( - self, client: TestClient, user_api_key: APIKeyPublic + self, client: TestClient, user_api_key: AuthContext ): """Test retrieving jobs with empty job_ids parameter.""" response = client.get( @@ -191,7 +192,7 @@ def test_get_jobs_with_empty_string( assert response.status_code == 422 def test_get_jobs_with_whitespace_only( - self, client: TestClient, user_api_key: APIKeyPublic + self, client: TestClient, user_api_key: AuthContext ): """Test retrieving jobs with whitespace-only job_ids.""" response = client.get( @@ -202,7 +203,7 @@ def test_get_jobs_with_whitespace_only( assert response.status_code == 422 def test_get_jobs_invalid_uuid_format_422( - self, client: TestClient, user_api_key: APIKeyPublic + self, client: TestClient, user_api_key: AuthContext ): """Test that invalid UUID format returns 422.""" invalid_uuid = "not-a-uuid" @@ -217,7 +218,7 @@ def test_get_jobs_invalid_uuid_format_422( assert "Input should be a valid UUID" in data["error"] def test_get_jobs_mixed_valid_invalid_uuid_422( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: AuthContext ): """Test that mixed valid/invalid UUIDs returns 422.""" store = DocumentStore(db, user_api_key.project_id) @@ -238,7 +239,7 @@ def test_get_jobs_mixed_valid_invalid_uuid_422( assert "job_ids" in data["error"] def test_get_jobs_missing_parameter_422( - self, client: TestClient, user_api_key: APIKeyPublic + self, client: TestClient, user_api_key: AuthContext ): """Test that missing job_ids parameter returns empty results.""" response = client.get( @@ -252,8 +253,8 @@ def test_get_jobs_different_project_not_found( self, client: TestClient, db: Session, - user_api_key: APIKeyPublic, - superuser_api_key: APIKeyPublic, + user_api_key: AuthContext, + superuser_api_key: AuthContext, ): """Test that jobs from different projects are not returned.""" store = DocumentStore(db, user_api_key.project_id) @@ -274,7 +275,7 @@ def test_get_jobs_different_project_not_found( assert data["data"]["jobs_not_found"][0] == str(job.id) def test_get_jobs_with_various_statuses( - self, client: TestClient, db: Session, user_api_key: APIKeyPublic + self, client: TestClient, db: Session, user_api_key: AuthContext ): """Test retrieving jobs with different statuses.""" store = DocumentStore(db, user_api_key.project_id) diff --git a/backend/app/tests/api/routes/test_openai_conversation.py b/backend/app/tests/api/routes/test_openai_conversation.py index dafc569f..733ce17e 100644 --- a/backend/app/tests/api/routes/test_openai_conversation.py +++ b/backend/app/tests/api/routes/test_openai_conversation.py @@ -2,15 +2,15 @@ from fastapi.testclient import TestClient from app.crud.openai_conversation import create_conversation -from app.models import APIKeyPublic from app.models import OpenAIConversationCreate from app.tests.utils.openai import generate_openai_id +from app.tests.utils.auth import AuthContext def test_get_conversation_success( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test successful conversation retrieval.""" @@ -45,7 +45,7 @@ def test_get_conversation_success( def test_get_conversation_not_found( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test conversation retrieval with non-existent ID.""" response = client.get( @@ -61,7 +61,7 @@ def test_get_conversation_not_found( def test_get_conversation_by_response_id_success( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test successful conversation retrieval by response ID.""" response_id = generate_openai_id("resp_", 40) @@ -96,7 +96,7 @@ def test_get_conversation_by_response_id_success( def test_get_conversation_by_response_id_not_found( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test conversation retrieval with non-existent response ID.""" response = client.get( @@ -112,7 +112,7 @@ def test_get_conversation_by_response_id_not_found( def test_get_conversation_by_ancestor_id_success( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test successful conversation retrieval by ancestor ID.""" ancestor_response_id = generate_openai_id("resp_", 40) @@ -147,7 +147,7 @@ def test_get_conversation_by_ancestor_id_success( def test_get_conversation_by_ancestor_id_not_found( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test conversation retrieval with non-existent ancestor ID.""" response = client.get( @@ -163,7 +163,7 @@ def test_get_conversation_by_ancestor_id_not_found( def test_list_conversations_success( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test successful conversation listing.""" conversation_data = OpenAIConversationCreate( @@ -199,7 +199,7 @@ def test_list_conversations_success( def test_list_conversations_with_pagination( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test conversation listing with pagination.""" # Create multiple conversations @@ -262,7 +262,7 @@ def test_list_conversations_with_pagination( def test_list_conversations_pagination_metadata( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test conversation listing pagination metadata.""" # Create 5 conversations @@ -320,7 +320,7 @@ def test_list_conversations_pagination_metadata( def test_list_conversations_default_pagination( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test conversation listing with default pagination parameters.""" # Create a conversation @@ -360,7 +360,7 @@ def test_list_conversations_default_pagination( def test_list_conversations_edge_cases( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test conversation listing edge cases for pagination.""" # Test with skip larger than total @@ -397,7 +397,7 @@ def test_list_conversations_edge_cases( def test_list_conversations_invalid_pagination( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test conversation listing with invalid pagination parameters.""" response = client.get( @@ -411,7 +411,7 @@ def test_list_conversations_invalid_pagination( def test_delete_conversation_success( client: TestClient, db: Session, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test successful conversation deletion.""" conversation_data = OpenAIConversationCreate( @@ -453,7 +453,7 @@ def test_delete_conversation_success( def test_delete_conversation_not_found( client: TestClient, - user_api_key: APIKeyPublic, + user_api_key: AuthContext, ): """Test conversation deletion with non-existent ID.""" response = client.delete( diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 954059e7..c299a909 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -14,9 +14,13 @@ from app.core.db import engine from app.api.deps import get_db from app.main import app -from app.models import APIKeyPublic from app.tests.utils.user import authentication_token_from_email -from app.tests.utils.utils import get_superuser_token_headers, get_api_key_by_email +from app.tests.utils.utils import get_superuser_token_headers +from app.tests.utils.auth import ( + get_superuser_auth_context, + get_user_auth_context, + AuthContext, +) from app.seed_data.seed_data import seed_database @@ -67,25 +71,25 @@ def normal_user_token_headers(client: TestClient, db: Session) -> dict[str, str] ) -@pytest.fixture(scope="function") +@pytest.fixture def superuser_api_key_header(db: Session) -> dict[str, str]: - api_key = get_api_key_by_email(db, settings.FIRST_SUPERUSER) - return {"X-API-KEY": api_key.key} + auth_ctx = get_superuser_auth_context(db) + return {"X-API-KEY": auth_ctx.key} -@pytest.fixture(scope="function") +@pytest.fixture def user_api_key_header(db: Session) -> dict[str, str]: - api_key = get_api_key_by_email(db, settings.EMAIL_TEST_USER) - return {"X-API-KEY": api_key.key} + auth_ctx = get_user_auth_context(db) + return {"X-API-KEY": auth_ctx.key} -@pytest.fixture(scope="function") -def superuser_api_key(db: Session) -> APIKeyPublic: - api_key = get_api_key_by_email(db, settings.FIRST_SUPERUSER) - return api_key +@pytest.fixture +def superuser_api_key(db: Session) -> AuthContext: + auth_ctx = get_superuser_auth_context(db) + return auth_ctx -@pytest.fixture(scope="function") -def user_api_key(db: Session) -> APIKeyPublic: - api_key = get_api_key_by_email(db, settings.EMAIL_TEST_USER) - return api_key +@pytest.fixture +def user_api_key(db: Session) -> AuthContext: + auth_ctx = get_user_auth_context(db) + return auth_ctx diff --git a/backend/app/tests/core/doctransformer/test_service/conftest.py b/backend/app/tests/core/doctransformer/test_service/conftest.py index 7f3aeda0..ccf7bc2c 100644 --- a/backend/app/tests/core/doctransformer/test_service/conftest.py +++ b/backend/app/tests/core/doctransformer/test_service/conftest.py @@ -16,7 +16,7 @@ from app.core.config import settings from app.models import Document, Project, UserProjectOrg from app.tests.utils.document import DocumentStore -from app.tests.utils.test_data import create_test_api_key +from app.tests.utils.auth import get_user_auth_context, AuthContext @pytest.fixture(scope="class") @@ -52,10 +52,10 @@ def fast_execute_job_func( @pytest.fixture -def current_user(db: Session) -> UserProjectOrg: +def current_user(db: Session, user_api_key: AuthContext) -> UserProjectOrg: """Create a test user for testing.""" - api_key = create_test_api_key(db) - user = db.get(User, api_key.user_id) + api_key = user_api_key + user = api_key.user return UserProjectOrg( **user.model_dump(), project_id=api_key.project_id, diff --git a/backend/app/tests/crud/test_api_key.py b/backend/app/tests/crud/test_api_key.py index 9f4281dc..e69de29b 100644 --- a/backend/app/tests/crud/test_api_key.py +++ b/backend/app/tests/crud/test_api_key.py @@ -1,117 +0,0 @@ -from sqlmodel import Session, select - -from app.crud import api_key as api_key_crud -from app.models import APIKey -from app.tests.utils.utils import get_non_existent_id -from app.tests.utils.user import create_random_user -from app.tests.utils.test_data import create_test_api_key, create_test_project - - -def test_create_api_key(db: Session) -> None: - user = create_random_user(db) - project = create_test_project(db) - - api_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id - ) - - assert api_key.key.startswith("ApiKey ") - assert len(api_key.key) > 32 - assert api_key.organization_id == project.organization_id - assert api_key.user_id == user.id - assert api_key.project_id == project.id - - -def test_get_api_key(db: Session) -> None: - api_key = create_test_api_key(db) - retrieved_key = api_key_crud.get_api_key(db, api_key.id) - - assert retrieved_key is not None - assert retrieved_key.id == api_key.id - assert retrieved_key.key.startswith("ApiKey ") - assert retrieved_key.project_id == api_key.project_id - - -def test_get_api_key_not_found(db: Session) -> None: - api_key_id = get_non_existent_id(db, APIKey) - result = api_key_crud.get_api_key(db, api_key_id) - assert result is None - - -def test_delete_api_key(db: Session) -> None: - api_key = create_test_api_key(db) - api_key_crud.delete_api_key(db, api_key.id) - - deleted_key = db.exec(select(APIKey).where(APIKey.id == api_key.id)).first() - - assert deleted_key is not None - assert deleted_key.is_deleted is True - assert deleted_key.deleted_at is not None - - -def test_get_api_key_by_value(db: Session) -> None: - api_key = create_test_api_key(db) - raw_key = api_key.key - - # Test retrieving the API key by its value - retrieved_key = api_key_crud.get_api_key_by_value(db, raw_key) - - assert retrieved_key is not None - assert retrieved_key.id == api_key.id - assert retrieved_key.organization_id == api_key.organization_id - assert retrieved_key.user_id == api_key.user_id - # The key should be in its original format - assert retrieved_key.key == raw_key # Should be exactly the same key - assert retrieved_key.key.startswith("ApiKey ") - assert len(retrieved_key.key) > 32 - - -def test_get_api_key_by_project_user(db: Session) -> None: - user = create_random_user(db) - project = create_test_project(db) - - created_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id - ) - retrieved_key = api_key_crud.get_api_key_by_project_user(db, project.id, user.id) - - assert retrieved_key is not None - assert retrieved_key.id == created_key.id - assert retrieved_key.project_id == project.id - assert retrieved_key.key.startswith("ApiKey ") - - -def test_get_api_keys_by_project(db: Session) -> None: - user = create_random_user(db) - project = create_test_project(db) - - created_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id - ) - - retrieved_keys = api_key_crud.get_api_keys_by_project(db, project.id) - - assert retrieved_keys is not None - assert len(retrieved_keys) == 1 - retrieved_key = retrieved_keys[0] - - assert retrieved_key.id == created_key.id - assert retrieved_key.project_id == project.id - assert retrieved_key.key.startswith("ApiKey ") - - -def test_get_api_key_by_user_id(db: Session) -> None: - user = create_random_user(db) - project = create_test_project(db) - - created_key = api_key_crud.create_api_key( - db, project.organization_id, user.id, project.id - ) - - retrieved_key = api_key_crud.get_api_key_by_user_id(db, user.id) - - assert retrieved_key is not None - - assert retrieved_key.id == created_key.id - assert retrieved_key.user_id == user.id - assert retrieved_key.key.startswith("ApiKey ") diff --git a/backend/app/tests/crud/test_onboarding.py b/backend/app/tests/crud/test_onboarding.py index 3ff64bbf..613669d1 100644 --- a/backend/app/tests/crud/test_onboarding.py +++ b/backend/app/tests/crud/test_onboarding.py @@ -229,7 +229,6 @@ def test_onboard_project_api_key_generation(db: Session) -> None: ) ).first() assert api_key_record is not None - assert api_key_record.key != response.api_key def test_onboard_project_response_data_integrity(db: Session) -> None: diff --git a/backend/app/tests/utils/auth.py b/backend/app/tests/utils/auth.py new file mode 100644 index 00000000..245b1ce7 --- /dev/null +++ b/backend/app/tests/utils/auth.py @@ -0,0 +1,137 @@ +from uuid import UUID + +from sqlmodel import Session, select, SQLModel + +from app.models import User, Organization, Project, APIKey +from app.core.config import settings + + +class AuthContext(SQLModel): + """Authentication context for testing""" + user_id: int + project_id: int + organization_id: int + key: str # The full unencrypted API key with "ApiKey " prefix + api_key_id: UUID # The UUID of the API key record + + # Complete nested objects + user: User + project: Project + organization: Organization + api_key: APIKey + + +def get_auth_context( + session: Session, + user_email: str, + project_name: str, + raw_key: str, + user_type: str = "User" +) -> AuthContext: + """ + Helper function to get authentication context from seeded data. + + Args: + session: Database session + user_email: Email of the user + project_name: Name of the project + raw_key: The full unencrypted API key with "ApiKey " prefix + user_type: Type of user for error messages (e.g., "Superuser", "User") + + Returns: + AuthContext with all IDs and keys from seeded data + + Raises: + ValueError: If the required data is not found in the database + """ + # Get user from seed data + user = session.exec(select(User).where(User.email == user_email)).first() + if not user: + raise ValueError(f"{user_type} with email {user_email} not found. Ensure seed data is loaded.") + + # Get project from seed data + project = session.exec(select(Project).where(Project.name == project_name)).first() + if not project: + raise ValueError(f"Project {project_name} not found. Ensure seed data is loaded.") + + # Get organization + org = session.exec(select(Organization).where(Organization.id == project.organization_id)).first() + if not org: + raise ValueError(f"Organization for project {project_name} not found.") + + # Get API key for this user and project + api_key = session.exec( + select(APIKey) + .where(APIKey.user_id == user.id) + .where(APIKey.project_id == project.id) + .where(APIKey.is_deleted == False) + ).first() + if not api_key: + raise ValueError(f"API key for {user_type.lower()} and project {project_name} not found.") + + # Return complete auth context + return AuthContext( + user_id=user.id, + project_id=project.id, + organization_id=org.id, + key=raw_key, + api_key_id=api_key.id, + user=user, + project=project, + organization=org, + api_key=api_key, + ) + + +def get_superuser_auth_context(session: Session) -> AuthContext: + """ + Get authentication context for superuser from seeded data. + + Uses SUPERUSER_EMAIL with Glific project based on seed_data.json: + - User: {{SUPERUSER_EMAIL}} (is_superuser: true) + - Project: Glific + - API Key: ApiKey No3x47A5qoIGhm0kVKjQ77dhCqEdWRIQZlEPzzzh7i8 + + Args: + session: Database session + + Returns: + AuthContext with all IDs and keys from seeded data + + Raises: + ValueError: If the required data is not found in the database + """ + return get_auth_context( + session=session, + user_email=settings.FIRST_SUPERUSER, + project_name="Glific", + raw_key="ApiKey No3x47A5qoIGhm0kVKjQ77dhCqEdWRIQZlEPzzzh7i8", + user_type="Superuser" + ) + + +def get_user_auth_context(session: Session) -> AuthContext: + """ + Get authentication context for normal user from seeded data. + + Uses ADMIN_EMAIL with Dalgo project based on seed_data.json: + - User: {{ADMIN_EMAIL}} (is_superuser: false) + - Project: Dalgo + - API Key: ApiKey Px8y47B6roJHin1lWLkR88eiDrFdXSJRZmFQazzai8j9 + + Args: + session: Database session + + Returns: + AuthContext with all IDs and keys from seeded data + + Raises: + ValueError: If the required data is not found in the database + """ + return get_auth_context( + session=session, + user_email=settings.EMAIL_TEST_USER, + project_name="Dalgo", + raw_key="ApiKey Px8y47B6roJHin1lWLkR88eiDrFdXSJRZmFQazzai8j", + user_type="User" + ) diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index b2d3ae94..111e8243 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -8,7 +8,7 @@ from app.models import Collection, Organization, Project from app.tests.utils.utils import get_user_id_by_email from app.tests.utils.test_data import create_test_project -from app.crud import create_api_key +from app.crud import APIKeyCrud class constants: @@ -29,12 +29,7 @@ def get_collection(db: Session, client=None, owner_id: int = None) -> Collection project = create_test_project(db) # Step 2: Create API key for user with valid foreign keys - create_api_key( - db, - organization_id=project.organization_id, - user_id=owner_id, - project_id=project.id, - ) + APIKeyCrud(session=db, project_id=project.id).create(user_id=owner_id) if client is None: client = OpenAI(api_key="test_api_key") diff --git a/backend/app/tests/utils/document.py b/backend/app/tests/utils/document.py index 5db59658..f6368f81 100644 --- a/backend/app/tests/utils/document.py +++ b/backend/app/tests/utils/document.py @@ -13,8 +13,9 @@ from app.core.config import settings from app.crud.project import get_project_by_id -from app.models import APIKeyPublic, Document, DocumentPublic, Project +from app.models import Document, DocumentPublic, Project from app.utils import APIResponse +from app.tests.utils.auth import AuthContext from .utils import SequentialUuidGenerator @@ -112,7 +113,7 @@ def append(self, doc: Document, suffix: str = None): @dataclass class WebCrawler: client: TestClient - user_api_key: APIKeyPublic + user_api_key: AuthContext def get(self, route: Route): return self.client.get( @@ -165,5 +166,5 @@ def to_public_dict(self) -> dict: @pytest.fixture -def crawler(client: TestClient, user_api_key: APIKeyPublic): +def crawler(client: TestClient, user_api_key: AuthContext): return WebCrawler(client, user_api_key=user_api_key) diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 616904f2..79d22351 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -17,10 +17,10 @@ from app.crud import ( create_organization, create_project, - create_api_key, set_creds_for_org, create_fine_tuning_job, create_model_evaluation, + APIKeyCrud, ) from app.core.providers import Provider from app.tests.utils.user import create_random_user @@ -30,6 +30,7 @@ get_document, get_project, ) +from app.tests.utils.auth import AuthContext, get_auth_context def create_test_organization(db: Session) -> Organization: @@ -62,23 +63,6 @@ def create_test_project(db: Session) -> Project: return create_project(session=db, project_create=project_in) -def create_test_api_key(db: Session) -> APIKey: - """ - Creates and returns an API key for a test project and test user. - - Persists a test user, organization, project, and API key to the database - """ - project = create_test_project(db) - user = create_random_user(db) - api_key = create_api_key( - db, - organization_id=project.organization_id, - user_id=user.id, - project_id=project.id, - ) - return api_key - - def test_credential_data(db: Session) -> CredsCreate: """ Returns credential data for a test project in the form of a CredsCreate schema. diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index 362b7434..0ea1e1d5 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -11,8 +11,7 @@ from app.core.config import settings from app.crud.user import get_user_by_email -from app.crud.api_key import get_api_key_by_value, get_api_key_by_user_id -from app.models import APIKeyPublic, Project, Assistant, Organization, Document +from app.models import Project, Assistant, Organization, Document T = TypeVar("T") @@ -42,26 +41,11 @@ def get_superuser_token_headers(client: TestClient) -> dict[str, str]: return headers -def get_api_key_by_email(db: Session, email: EmailStr) -> APIKeyPublic: - user = get_user_by_email(session=db, email=email) - api_key = get_api_key_by_user_id(db, user_id=user.id) - - return api_key - - def get_user_id_by_email(db: Session) -> int: user = get_user_by_email(session=db, email=settings.EMAIL_TEST_USER) return user.id -def get_user_from_api_key(db: Session, api_key_headers: dict[str, str]) -> APIKeyPublic: - key_value = api_key_headers["X-API-KEY"] - api_key = get_api_key_by_value(db, api_key_value=key_value) - if api_key is None: - raise ValueError("Invalid API Key") - return api_key - - def get_non_existent_id(session: Session, model: Type[T]) -> int: result = session.exec(select(model.id).order_by(model.id.desc())).first() return (result or 0) + 1 @@ -89,22 +73,24 @@ def get_project(session: Session, name: str | None = None) -> Project: return project -def get_assistant(session: Session, name: str | None = None) -> Assistant: +def get_assistant( + session: Session, project_id: int | None = None, name: str | None = None +) -> Assistant: """ Retrieve an active assistant from the database. If a assistant name is provided, fetch the active assistant with that name. If no name is provided, fetch any random assistant. """ + filters = [Assistant.is_deleted == False] + + if project_id is not None: + filters.append(Assistant.project_id == project_id) + if name: - statement = ( - select(Assistant) - .where(Assistant.name == name, Assistant.is_deleted == False) - .limit(1) - ) - else: - statement = select(Assistant).where(Assistant.is_deleted == False).limit(1) + filters.append(Assistant.name == name) + statement = select(Assistant).where(*filters).limit(1) assistant = session.exec(statement).first() if not assistant: From 14eb31ce391ca0b2443029e53cdead1c4d5d27af Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Mon, 6 Oct 2025 16:53:37 +0530 Subject: [PATCH 10/37] Fix API key route prefix and improve formatting in auth context functions --- backend/app/api/routes/api_keys.py | 2 +- backend/app/tests/utils/auth.py | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index b0198615..6eb594d6 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -7,7 +7,7 @@ from app.utils import APIResponse from app.api.permissions import Permission, require_permission -router = APIRouter(prefix="/api-keys", tags=["API Keys"]) +router = APIRouter(prefix="/apikeys", tags=["API Keys"]) @router.post( diff --git a/backend/app/tests/utils/auth.py b/backend/app/tests/utils/auth.py index 245b1ce7..9a419641 100644 --- a/backend/app/tests/utils/auth.py +++ b/backend/app/tests/utils/auth.py @@ -8,6 +8,7 @@ class AuthContext(SQLModel): """Authentication context for testing""" + user_id: int project_id: int organization_id: int @@ -26,7 +27,7 @@ def get_auth_context( user_email: str, project_name: str, raw_key: str, - user_type: str = "User" + user_type: str = "User", ) -> AuthContext: """ Helper function to get authentication context from seeded data. @@ -47,15 +48,21 @@ def get_auth_context( # Get user from seed data user = session.exec(select(User).where(User.email == user_email)).first() if not user: - raise ValueError(f"{user_type} with email {user_email} not found. Ensure seed data is loaded.") + raise ValueError( + f"{user_type} with email {user_email} not found. Ensure seed data is loaded." + ) # Get project from seed data project = session.exec(select(Project).where(Project.name == project_name)).first() if not project: - raise ValueError(f"Project {project_name} not found. Ensure seed data is loaded.") + raise ValueError( + f"Project {project_name} not found. Ensure seed data is loaded." + ) # Get organization - org = session.exec(select(Organization).where(Organization.id == project.organization_id)).first() + org = session.exec( + select(Organization).where(Organization.id == project.organization_id) + ).first() if not org: raise ValueError(f"Organization for project {project_name} not found.") @@ -67,7 +74,9 @@ def get_auth_context( .where(APIKey.is_deleted == False) ).first() if not api_key: - raise ValueError(f"API key for {user_type.lower()} and project {project_name} not found.") + raise ValueError( + f"API key for {user_type.lower()} and project {project_name} not found." + ) # Return complete auth context return AuthContext( @@ -106,7 +115,7 @@ def get_superuser_auth_context(session: Session) -> AuthContext: user_email=settings.FIRST_SUPERUSER, project_name="Glific", raw_key="ApiKey No3x47A5qoIGhm0kVKjQ77dhCqEdWRIQZlEPzzzh7i8", - user_type="Superuser" + user_type="Superuser", ) @@ -133,5 +142,5 @@ def get_user_auth_context(session: Session) -> AuthContext: user_email=settings.EMAIL_TEST_USER, project_name="Dalgo", raw_key="ApiKey Px8y47B6roJHin1lWLkR88eiDrFdXSJRZmFQazzai8j", - user_type="User" + user_type="User", ) From e9376cd8e70bb8ce58e6b60c6c629af51f497825 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 08:30:36 +0530 Subject: [PATCH 11/37] move APIKeyManager to security module and clean up imports in CRUD operations --- backend/app/api/deps.py | 2 +- backend/app/core/security.py | 134 ++++++++++++++++++++++++++++++++++- backend/app/crud/api_key.py | 125 +------------------------------- 3 files changed, 134 insertions(+), 127 deletions(-) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 81b9ac9e..0f3cd445 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -12,9 +12,9 @@ from app.core import security from app.core.config import settings from app.core.db import engine +from app.core.security import api_key_manager from app.utils import APIResponse from app.crud.organization import validate_organization -from app.crud.api_key import api_key_manager from app.models import ( TokenPayload, User, diff --git a/backend/app/core/security.py b/backend/app/core/security.py index ace78c3a..dfa5768c 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -7,19 +7,25 @@ - Credentials encryption/decryption """ -from datetime import datetime, timedelta, timezone -from typing import Any import base64 import json +import logging +import secrets +from datetime import datetime, timedelta, timezone +from typing import Any, Tuple import jwt from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from passlib.context import CryptContext +from sqlmodel import Session, and_, select from app.core.config import settings + +logger = logging.getLogger(__name__) + # Password hashing configuration pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -179,3 +185,127 @@ def decrypt_credentials(encrypted_credentials: str) -> dict: return json.loads(decrypted_str) except Exception as e: raise ValueError(f"Failed to decrypt credentials: {e}") + + +class APIKeyManager: + """ + Handles API key generation and verification using secure hashing. + Supports Backwards compatibility with old key format. + """ + + # Configuration constants + PREFIX_NAME = "ApiKey " + PREFIX_BYTES = 16 # Generates 22 chars in urlsafe base64 + SECRET_BYTES = 32 # Generates 43 chars in urlsafe base64 + PREFIX_LENGTH = 22 + KEY_LENGTH = 65 # Total length: 22 (prefix) + 43 (secret) + HASH_ALGORITHM = "bcrypt" + + pwd_context = CryptContext(schemes=[HASH_ALGORITHM], deprecated="auto") + + @classmethod + def generate(cls) -> Tuple[str, str, str]: + """ + Generate a new API key with prefix and hashed value. + Ensures exact lengths: prefix=22 chars, secret=43 chars. + + Returns: + Tuple of (raw_key, key_prefix, key_hash) + """ + # Generate tokens and ensure exact length + secret_length = cls.KEY_LENGTH - cls.PREFIX_LENGTH + key_prefix = secrets.token_urlsafe(cls.PREFIX_BYTES)[: cls.PREFIX_LENGTH].ljust( + cls.PREFIX_LENGTH, "A" + ) + secret_key = secrets.token_urlsafe(cls.SECRET_BYTES)[:secret_length].ljust( + secret_length, "A" + ) + + # Construct raw key: "ApiKey {prefix}{secret}" + raw_key = f"{cls.PREFIX_NAME}{key_prefix}{secret_key}" + + key_hash = cls.pwd_context.hash(secret_key) + + return raw_key, key_prefix, key_hash + + @classmethod + def _extract_key_parts(cls, raw_key: str) -> Tuple[str, str] | None: + """ + Extract prefix and secret from an API key based on its format. + + Supports: + - New format: "ApiKey {22-char-prefix}{43-char-secret}" + - Old format: "ApiKey {12-char-prefix}{31-char-secret}" + + Returns: + Tuple[str, str] -> (key_prefix, secret_to_verify) + or None if invalid + """ + if not raw_key.startswith(cls.PREFIX_NAME): + return None + + key_part = raw_key[len(cls.PREFIX_NAME) :] + + if len(key_part) == cls.KEY_LENGTH: + key_prefix = key_part[: cls.PREFIX_LENGTH] + secret_key = key_part[cls.PREFIX_LENGTH :] + return key_prefix, secret_key + + old_key_length = 43 + old_prefix_length = 12 + if len(key_part) == old_key_length: + key_prefix = key_part[:old_prefix_length] + secret_key = key_part[old_prefix_length:] + return key_prefix, secret_key + + # Invalid format + return None + + @classmethod + def verify(cls, session: Session, raw_key: str): + """ + Verify an API key by checking its prefix and hashed value. + Supports both old (43 chars) and new ("ApiKey " + 65 chars) formats. + + Args: + session: Database session + raw_key: The raw API key to verify + + Returns: + The APIKey record if valid, None otherwise + """ + from app.models import APIKey + + try: + key_parts = cls._extract_key_parts(raw_key) + + if not key_parts: + return None + + key_prefix, secret = key_parts + + statement = select(APIKey).where( + and_( + APIKey.key_prefix == key_prefix, + APIKey.is_deleted.is_(False), + ) + ) + api_key_record = session.exec(statement).one_or_none() + + if not api_key_record: + return None + + if cls.pwd_context.verify(secret, api_key_record.key_hash): + return api_key_record + + return None + + except Exception as e: + logger.error( + f"[APIKeyManager.verify] Error verifying API key: {str(e)}", + exc_info=True, + ) + return None + + +api_key_manager = APIKeyManager() diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index 6620d9d6..99186cd6 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -1,141 +1,18 @@ import logging -import secrets from uuid import UUID from typing import Tuple from sqlmodel import Session, select, and_ from fastapi import HTTPException -from passlib.context import CryptContext from app.models import APIKey from app.crud import get_project_by_id from app.core.util import now +from app.core.security import api_key_manager logger = logging.getLogger(__name__) -class APIKeyManager: - """ - Handles API key generation and verification using secure hashing. - Supports Backwards compatibility with old key format. - """ - - # Configuration constants - PREFIX_NAME = "ApiKey " - PREFIX_BYTES = 16 # Generates 22 chars in urlsafe base64 - SECRET_BYTES = 32 # Generates 43 chars in urlsafe base64 - PREFIX_LENGTH = 22 - KEY_LENGTH = 65 # Total length: 22 (prefix) + 43 (secret) - HASH_ALGORITHM = "bcrypt" - - pwd_context = CryptContext(schemes=[HASH_ALGORITHM], deprecated="auto") - - @classmethod - def generate(cls) -> Tuple[str, str, str]: - """ - Generate a new API key with prefix and hashed value. - Ensures exact lengths: prefix=22 chars, secret=43 chars. - - Returns: - Tuple of (raw_key, key_prefix, key_hash) - """ - # Generate tokens and ensure exact length - secret_length = cls.KEY_LENGTH - cls.PREFIX_LENGTH - key_prefix = secrets.token_urlsafe(cls.PREFIX_BYTES)[: cls.PREFIX_LENGTH].ljust( - cls.PREFIX_LENGTH, "A" - ) - secret_key = secrets.token_urlsafe(cls.SECRET_BYTES)[:secret_length].ljust( - secret_length, "A" - ) - - # Construct raw key: "ApiKey {prefix}{secret}" - raw_key = f"{cls.PREFIX_NAME}{key_prefix}{secret_key}" - - key_hash = cls.pwd_context.hash(secret_key) - - return raw_key, key_prefix, key_hash - - @classmethod - def _extract_key_parts(cls, raw_key: str) -> Tuple[str, str] | None: - """ - Extract prefix and secret from an API key based on its format. - - Supports: - - New format: "ApiKey {22-char-prefix}{43-char-secret}" - - Old format: "ApiKey {12-char-prefix}{31-char-secret}" - - Returns: - Tuple[str, str] -> (key_prefix, secret_to_verify) - or None if invalid - """ - if not raw_key.startswith(cls.PREFIX_NAME): - return None - - key_part = raw_key[len(cls.PREFIX_NAME) :] - - if len(key_part) == cls.KEY_LENGTH: - key_prefix = key_part[: cls.PREFIX_LENGTH] - secret_key = key_part[cls.PREFIX_LENGTH :] - return key_prefix, secret_key - - old_key_length = 43 - old_prefix_length = 12 - if len(key_part) == old_key_length: - key_prefix = key_part[:old_prefix_length] - secret_key = key_part[old_prefix_length:] - return key_prefix, secret_key - - # Invalid format - return None - - @classmethod - def verify(cls, session: Session, raw_key: str) -> APIKey | None: - """ - Verify an API key by checking its prefix and hashed value. - Supports both old (43 chars) and new ("ApiKey " + 65 chars) formats. - - Args: - session: Database session - raw_key: The raw API key to verify - - Returns: - The APIKey record if valid, None otherwise - """ - try: - key_parts = cls._extract_key_parts(raw_key) - - if not key_parts: - return None - - key_prefix, secret = key_parts - - statement = select(APIKey).where( - and_( - APIKey.key_prefix == key_prefix, - APIKey.is_deleted.is_(False), - ) - ) - api_key_record = session.exec(statement).one_or_none() - - if not api_key_record: - return None - - if cls.pwd_context.verify(secret, api_key_record.key_hash): - return api_key_record - - return None - - except Exception as e: - logger.error( - f"[APIKeyManager.verify] Error verifying API key: {str(e)}", - exc_info=True, - ) - return None - - -api_key_manager = APIKeyManager() - - class APIKeyCrud: """ CRUD operations for API keys scoped to a project. From b3237c34099b7989a588c4e7a0b4f80aa713435d Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 09:54:14 +0530 Subject: [PATCH 12/37] user context handling to use AuthContext, updating dependencies and verification logic for API key authentication --- backend/app/api/deps.py | 40 ++++++++++++++--------------- backend/app/api/permissions.py | 20 +++++++-------- backend/app/api/routes/api_keys.py | 8 +++--- backend/app/core/security.py | 41 ++++++++++++++++++++++-------- backend/app/models/__init__.py | 3 +-- backend/app/models/auth.py | 13 ++++++++++ backend/app/models/user.py | 6 ----- 7 files changed, 78 insertions(+), 53 deletions(-) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 0f3cd445..70ee5715 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -16,9 +16,9 @@ from app.utils import APIResponse from app.crud.organization import validate_organization from app.models import ( + AuthContext, TokenPayload, User, - UserContext, UserProjectOrg, UserOrganization, ProjectUser, @@ -155,29 +155,27 @@ def get_user_context( session: SessionDep, token: TokenDep, api_key: Annotated[str, Depends(api_key_header)], -) -> UserContext: +) -> AuthContext: """ Verify valid authentication (API Key or JWT token) and return authenticated user context. - Returns UserContext with user info, project_id, and organization_id. + Returns AuthContext with user info, project_id, and organization_id. Authorization logic should be handled in routes. """ if api_key: - api_key_record = api_key_manager.verify(session, api_key) - if not api_key_record: + auth_context = api_key_manager.verify(session, api_key) + if not auth_context: raise HTTPException(status_code=401, detail="Invalid API Key") - user = session.get(User, api_key_record.user_id) - if not user: - raise HTTPException( - status_code=404, detail="User linked to API Key not found" - ) + if not auth_context.user.is_active: + raise HTTPException(status_code=400, detail="Inactive user") - user_context = UserContext( - **user.model_dump(), - project_id=api_key_record.project_id, - organization_id=api_key_record.organization_id, - ) - return user_context + if not auth_context.organization.is_active: + raise HTTPException(status_code=400, detail="Inactive Organization") + + if not auth_context.project.is_active: + raise HTTPException(status_code=400, detail="Inactive Project") + + return auth_context elif token: try: @@ -197,15 +195,17 @@ def get_user_context( if not user.is_active: raise HTTPException(status_code=400, detail="Inactive user") - user_context = UserContext(**user.model_dump()) - - return user_context + auth_context = AuthContext( + user_id=user.id, + user=user, + ) + return auth_context else: raise HTTPException(status_code=401, detail="Invalid Authorization format") -UserContextDep = Annotated[UserContext, Depends(get_user_context)] +AuthContextDep = Annotated[AuthContext, Depends(get_user_context)] def verify_user_project_organization( diff --git a/backend/app/api/permissions.py b/backend/app/api/permissions.py index 21850cf2..b5a99c52 100644 --- a/backend/app/api/permissions.py +++ b/backend/app/api/permissions.py @@ -3,8 +3,8 @@ from fastapi import Depends, HTTPException from sqlmodel import Session -from app.models import UserContext -from app.api.deps import UserContextDep, SessionDep +from app.models import AuthContext +from app.api.deps import AuthContextDep, SessionDep class Permission(str, Enum): @@ -16,12 +16,12 @@ class Permission(str, Enum): def has_permission( - user_context: UserContext, + auth_context: AuthContext, permission: Permission, session: Session | None = None, ) -> bool: """ - Check if the user_context has the specified permission. + Check if the auth_context has the specified permission. Args: user_context: The authenticated user context @@ -33,11 +33,11 @@ def has_permission( """ match permission: case Permission.SUPERUSER: - return user_context.is_superuser + return auth_context.user.is_superuser case Permission.REQUIRE_ORGANIZATION: - return user_context.organization_id is not None + return auth_context.organization_id is not None case Permission.REQUIRE_PROJECT: - return user_context.project_id is not None + return auth_context.project_id is not None case _: return False @@ -48,15 +48,15 @@ def require_permission(permission: Permission): Usage: @app.get("/endpoint", dependencies=[Depends(require_permission(Permission.REQUIRE_ORGANIZATION))]) - def endpoint(user_context: Annotated[UserContext, Depends(get_user_context)]): + def endpoint(auth_context: Annotated[AuthContext, Depends(get_user_context)]): pass """ def permission_checker( - user_context: UserContextDep, + auth_context: AuthContextDep, session: SessionDep, ): - if not has_permission(user_context, permission, session): + if not has_permission(auth_context, permission, session): error_messages = { Permission.SUPERUSER: "Insufficient permissions - require superuser access.", Permission.REQUIRE_ORGANIZATION: "Insufficient permissions - require organization access.", diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index 6eb594d6..970070ec 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -1,7 +1,7 @@ from uuid import UUID from fastapi import APIRouter, Depends, Query -from app.api.deps import SessionDep, UserContextDep +from app.api.deps import SessionDep, AuthContextDep from app.crud.api_key import APIKeyCrud from app.models import APIKeyPublic, APIKeyCreateResponse, Message from app.utils import APIResponse @@ -18,7 +18,7 @@ ) def create_api_key_route( project_id: int, - current_user: UserContextDep, + current_user: AuthContextDep, session: SessionDep, ): """ @@ -43,7 +43,7 @@ def create_api_key_route( dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], ) def list_api_keys_route( - current_user: UserContextDep, + current_user: AuthContextDep, session: SessionDep, skip: int = Query(0, ge=0, description="Number of records to skip"), limit: int = Query(100, ge=1, le=100, description="Maximum records to return"), @@ -67,7 +67,7 @@ def list_api_keys_route( ) def delete_api_key_route( key_id: UUID, - current_user: UserContextDep, + current_user: AuthContextDep, session: SessionDep, ): """ diff --git a/backend/app/core/security.py b/backend/app/core/security.py index dfa5768c..ddd24a16 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -21,6 +21,7 @@ from passlib.context import CryptContext from sqlmodel import Session, and_, select +from app.models import APIKey, User, Organization, Project, AuthContext from app.core.config import settings @@ -262,20 +263,20 @@ def _extract_key_parts(cls, raw_key: str) -> Tuple[str, str] | None: return None @classmethod - def verify(cls, session: Session, raw_key: str): + def verify(cls, session: Session, raw_key: str) -> AuthContext | None: """ Verify an API key by checking its prefix and hashed value. Supports both old (43 chars) and new ("ApiKey " + 65 chars) formats. + Eagerly loads User, Organization, and Project in a single query. + Args: session: Database session raw_key: The raw API key to verify Returns: - The APIKey record if valid, None otherwise + Tuple of (APIKey, User, Organization, Project) if valid, None otherwise """ - from app.models import APIKey - try: key_parts = cls._extract_key_parts(raw_key) @@ -284,19 +285,37 @@ def verify(cls, session: Session, raw_key: str): key_prefix, secret = key_parts - statement = select(APIKey).where( - and_( - APIKey.key_prefix == key_prefix, - APIKey.is_deleted.is_(False), + # Single query to fetch APIKey with User, Organization, and Project + statement = ( + select(APIKey, User, Organization, Project) + .where( + and_( + APIKey.key_prefix == key_prefix, + APIKey.is_deleted.is_(False), + ) ) + .join(User, User.id == APIKey.user_id) + .join(Organization, Organization.id == APIKey.organization_id) + .join(Project, Project.id == APIKey.project_id) ) - api_key_record = session.exec(statement).one_or_none() - if not api_key_record: + result = session.exec(statement).first() + + if not result: return None + api_key_record, user, organization, project = result + auth_context = AuthContext( + user_id=user.id, + project_id=project.id, + organization_id=organization.id, + user=user, + project=project, + organization=organization, + ) + # Verify the secret hash if cls.pwd_context.verify(secret, api_key_record.key_hash): - return api_key_record + return auth_context return None diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index a885bfb5..5ba1ffd0 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,6 +1,6 @@ from sqlmodel import SQLModel -from .auth import Token, TokenPayload +from .auth import AuthContext, Token, TokenPayload from .api_key import APIKey, APIKeyBase, APIKeyPublic, APIKeyCreateResponse @@ -48,7 +48,6 @@ NewPassword, User, UserCreate, - UserContext, UserOrganization, UserProjectOrg, UserPublic, diff --git a/backend/app/models/auth.py b/backend/app/models/auth.py index 7355c383..bfad680b 100644 --- a/backend/app/models/auth.py +++ b/backend/app/models/auth.py @@ -1,4 +1,7 @@ from sqlmodel import Field, SQLModel +from app.models.user import User +from app.models.organization import Organization +from app.models.project import Project # JSON payload containing access token @@ -10,3 +13,13 @@ class Token(SQLModel): # Contents of JWT token class TokenPayload(SQLModel): sub: str | None = None + + +class AuthContext(SQLModel): + user_id: int = Field(foreign_key="user.id") + project_id: int | None = None + organization_id: int | None = None + + user: User + organization: Organization | None = None + project: Project | None = None diff --git a/backend/app/models/user.py b/backend/app/models/user.py index f3523948..52512ab4 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -65,12 +65,6 @@ class UserProjectOrg(UserOrganization): project_id: int -class UserContext(UserBase): - id: int - project_id: int | None = None - organization_id: int | None = None - - # Properties to return via API, id is always required class UserPublic(UserBase): id: int From 029c00ed0fd72de0bcd4c30a52c64d859185e400 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 09:58:07 +0530 Subject: [PATCH 13/37] rename AuthContext to TestAuthContext for consistency in test files --- .../collections/test_collection_info.py | 10 +++--- .../collections/test_create_collections.py | 4 +-- .../documents/test_route_document_upload.py | 4 +-- .../app/tests/api/routes/test_assistants.py | 20 +++++------ backend/app/tests/api/routes/test_creds.py | 34 +++++++++--------- .../api/routes/test_doc_transformation_job.py | 36 +++++++++---------- .../api/routes/test_openai_conversation.py | 30 ++++++++-------- backend/app/tests/conftest.py | 6 ++-- .../doctransformer/test_service/conftest.py | 4 +-- backend/app/tests/utils/auth.py | 16 ++++----- backend/app/tests/utils/document.py | 6 ++-- backend/app/tests/utils/test_data.py | 2 +- 12 files changed, 86 insertions(+), 86 deletions(-) diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 66ebb01e..d970a283 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -5,12 +5,12 @@ from app.core.config import settings from app.models import Collection from app.models.collection import CollectionStatus -from app.tests.utils.auth import AuthContext +from app.tests.utils.auth import TestAuthContext def create_collection( db, - user_api_key: AuthContext, + user_api_key: TestAuthContext, status: CollectionStatus = CollectionStatus.processing, with_llm: bool = False, ): @@ -34,7 +34,7 @@ def create_collection( def test_collection_info_processing( - db: Session, client: TestClient, user_api_key: AuthContext + db: Session, client: TestClient, user_api_key: TestAuthContext ): headers = {"X-API-KEY": user_api_key.key} collection = create_collection(db, user_api_key, status=CollectionStatus.processing) @@ -54,7 +54,7 @@ def test_collection_info_processing( def test_collection_info_successful( - db: Session, client: TestClient, user_api_key: AuthContext + db: Session, client: TestClient, user_api_key: TestAuthContext ): headers = {"X-API-KEY": user_api_key.key} collection = create_collection( @@ -76,7 +76,7 @@ def test_collection_info_successful( def test_collection_info_failed( - db: Session, client: TestClient, user_api_key: AuthContext + db: Session, client: TestClient, user_api_key: TestAuthContext ): headers = {"X-API-KEY": user_api_key.key} collection = create_collection(db, user_api_key, status=CollectionStatus.failed) diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index ff1e00af..03d1c025 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -11,7 +11,7 @@ from app.crud.collection import CollectionCrud from app.models.collection import CollectionStatus from app.tests.utils.openai import get_mock_openai_client_with_vector_store -from app.tests.utils.auth import AuthContext +from app.tests.utils.auth import TestAuthContext @pytest.fixture(autouse=True) @@ -48,7 +48,7 @@ def test_create_collection_success( mock_get_openai_client, client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): store = DocumentStore(db, project_id=user_api_key.project_id) documents = store.fill(self._n_documents) diff --git a/backend/app/tests/api/routes/documents/test_route_document_upload.py b/backend/app/tests/api/routes/documents/test_route_document_upload.py index de6d6dd1..320c1992 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_upload.py +++ b/backend/app/tests/api/routes/documents/test_route_document_upload.py @@ -18,7 +18,7 @@ WebCrawler, httpx_to_standard, ) -from app.tests.utils.auth import AuthContext +from app.tests.utils.auth import TestAuthContext class WebUploader(WebCrawler): @@ -77,7 +77,7 @@ def route(): @pytest.fixture -def uploader(client: TestClient, user_api_key: AuthContext): +def uploader(client: TestClient, user_api_key: TestAuthContext): return WebUploader(client, user_api_key) diff --git a/backend/app/tests/api/routes/test_assistants.py b/backend/app/tests/api/routes/test_assistants.py index e5f73fe9..d4d2aadc 100644 --- a/backend/app/tests/api/routes/test_assistants.py +++ b/backend/app/tests/api/routes/test_assistants.py @@ -6,7 +6,7 @@ from unittest.mock import patch from app.tests.utils.openai import mock_openai_assistant from app.tests.utils.utils import get_assistant -from app.tests.utils.auth import AuthContext +from app.tests.utils.auth import TestAuthContext @pytest.fixture @@ -30,7 +30,7 @@ def assistant_id(): def test_ingest_assistant_success( mock_fetch_assistant, client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful assistant ingestion from OpenAI.""" mock_assistant = mock_openai_assistant() @@ -53,7 +53,7 @@ def test_create_assistant_success( mock_verify_vector_ids, client: TestClient, assistant_create_payload: dict, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful assistant creation with OpenAI vector store ID verification.""" @@ -92,7 +92,7 @@ def test_create_assistant_invalid_vector_store( mock_verify_vector_ids, client: TestClient, assistant_create_payload: dict, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test failure when one or more vector store IDs are invalid.""" @@ -117,7 +117,7 @@ def test_create_assistant_invalid_vector_store( def test_update_assistant_success( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful assistant update.""" update_payload = { @@ -151,7 +151,7 @@ def test_update_assistant_invalid_vector_store( mock_verify_vector_ids, client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test failure when updating assistant with invalid vector store IDs.""" mock_verify_vector_ids.side_effect = HTTPException( @@ -175,7 +175,7 @@ def test_update_assistant_invalid_vector_store( def test_update_assistant_not_found( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test failure when updating a non-existent assistant.""" update_payload = {"name": "Updated Assistant"} @@ -196,7 +196,7 @@ def test_update_assistant_not_found( def test_get_assistant_success( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful retrieval of a single assistant.""" assistant = get_assistant(db, project_id=user_api_key.project_id) @@ -235,7 +235,7 @@ def test_get_assistant_not_found( def test_list_assistants_success( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful retrieval of assistants list.""" assistant = get_assistant(db, project_id=user_api_key.project_id) @@ -286,7 +286,7 @@ def test_list_assistants_invalid_pagination( def test_delete_assistant_success( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful soft deletion of an assistant.""" assistant = get_assistant(db, project_id=user_api_key.project_id) diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 7832bc7c..a3d00fd1 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -13,7 +13,7 @@ create_test_credential, test_credential_data, ) -from app.tests.utils.auth import AuthContext +from app.tests.utils.auth import TestAuthContext @pytest.fixture @@ -23,7 +23,7 @@ def create_test_credentials(db: Session): def test_set_credential( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): project_id = user_api_key.project_id org_id = user_api_key.organization_id @@ -65,7 +65,7 @@ def test_set_credential( def test_set_credentials_ignored_mismatched_ids( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): # Even if mismatched IDs are sent, route uses API key context # Ensure clean state for provider @@ -90,7 +90,7 @@ def test_set_credentials_ignored_mismatched_ids( def test_read_credentials_with_creds( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): # Ensure at least one credential exists for current project api_key_value = "sk-" + generate_random_string(10) @@ -120,7 +120,7 @@ def test_read_credentials_with_creds( def test_read_credentials_not_found( - client: TestClient, db: Session, user_api_key: AuthContext + client: TestClient, db: Session, user_api_key: TestAuthContext ): # Delete all first to ensure none remain client.delete( @@ -136,7 +136,7 @@ def test_read_credentials_not_found( def test_read_provider_credential( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): # Ensure exists client.delete( @@ -168,7 +168,7 @@ def test_read_provider_credential( def test_read_provider_credential_not_found( - client: TestClient, db: Session, user_api_key: AuthContext + client: TestClient, db: Session, user_api_key: TestAuthContext ): # Ensure none client.delete( @@ -185,7 +185,7 @@ def test_read_provider_credential_not_found( def test_update_credentials( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): # Ensure exists client.delete( @@ -230,7 +230,7 @@ def test_update_credentials( def test_update_credentials_not_found_for_provider( - client: TestClient, db: Session, user_api_key: AuthContext + client: TestClient, db: Session, user_api_key: TestAuthContext ): # Ensure none exist client.delete( @@ -257,7 +257,7 @@ def test_update_credentials_not_found_for_provider( def test_delete_provider_credential( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): # Ensure exists client.delete( @@ -284,7 +284,7 @@ def test_delete_provider_credential( def test_delete_provider_credential_not_found( - client: TestClient, db: Session, user_api_key: AuthContext + client: TestClient, db: Session, user_api_key: TestAuthContext ): # Ensure not exists client.delete( @@ -301,7 +301,7 @@ def test_delete_provider_credential_not_found( def test_delete_all_credentials( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): # Ensure exists client.delete( @@ -339,7 +339,7 @@ def test_delete_all_credentials( def test_delete_all_credentials_not_found( - client: TestClient, db: Session, user_api_key: AuthContext + client: TestClient, db: Session, user_api_key: TestAuthContext ): # Ensure already deleted client.delete( @@ -357,7 +357,7 @@ def test_delete_all_credentials_not_found( def test_duplicate_credential_creation( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): credential = test_credential_data(db) # Ensure clean state for provider @@ -386,7 +386,7 @@ def test_duplicate_credential_creation( def test_multiple_provider_credentials( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): # Ensure clean state for current org/project client.delete( @@ -453,7 +453,7 @@ def test_multiple_provider_credentials( def test_credential_encryption( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): credential = test_credential_data(db) original_api_key = credential.credential[Provider.OPENAI.value]["api_key"] @@ -491,7 +491,7 @@ def test_credential_encryption( def test_credential_encryption_consistency( - client: TestClient, db: Session, user_api_key: AuthContext + client: TestClient, db: Session, user_api_key: TestAuthContext ): credentials = test_credential_data(db) original_api_key = credentials.credential[Provider.OPENAI.value]["api_key"] diff --git a/backend/app/tests/api/routes/test_doc_transformation_job.py b/backend/app/tests/api/routes/test_doc_transformation_job.py index d2f4510a..92460058 100644 --- a/backend/app/tests/api/routes/test_doc_transformation_job.py +++ b/backend/app/tests/api/routes/test_doc_transformation_job.py @@ -5,12 +5,12 @@ from app.crud.doc_transformation_job import DocTransformationJobCrud from app.models import TransformationStatus from app.tests.utils.document import DocumentStore -from app.tests.utils.auth import AuthContext +from app.tests.utils.auth import TestAuthContext class TestGetTransformationJob: def test_get_existing_job_success( - self, client: TestClient, db: Session, user_api_key: AuthContext + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test successfully retrieving an existing transformation job.""" document = DocumentStore(db, user_api_key.project_id).put() @@ -32,7 +32,7 @@ def test_get_existing_job_success( assert data["data"]["transformed_document_id"] is None def test_get_nonexistent_job_404( - self, client: TestClient, db: Session, user_api_key: AuthContext + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test getting a non-existent transformation job returns 404.""" fake_uuid = "00000000-0000-0000-0000-000000000001" @@ -45,7 +45,7 @@ def test_get_nonexistent_job_404( assert response.status_code == 404 def test_get_job_invalid_uuid_422( - self, client: TestClient, user_api_key: AuthContext + self, client: TestClient, user_api_key: TestAuthContext ): """Test getting a job with invalid UUID format returns 422.""" invalid_uuid = "not-a-uuid" @@ -61,8 +61,8 @@ def test_get_job_different_project_404( self, client: TestClient, db: Session, - user_api_key: AuthContext, - superuser_api_key: AuthContext, + user_api_key: TestAuthContext, + superuser_api_key: TestAuthContext, ): """Test that jobs from different projects are not accessible.""" store = DocumentStore(db, user_api_key.project_id) @@ -79,7 +79,7 @@ def test_get_job_different_project_404( assert response.status_code == 404 def test_get_completed_job_with_result( - self, client: TestClient, db: Session, user_api_key: AuthContext + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test getting a completed job with transformation result.""" store = DocumentStore(db, user_api_key.project_id) @@ -106,7 +106,7 @@ def test_get_completed_job_with_result( assert data["data"]["transformed_document_id"] == str(transformed_document.id) def test_get_failed_job_with_error( - self, client: TestClient, db: Session, user_api_key: AuthContext + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test getting a failed job with error message.""" store = DocumentStore(db, user_api_key.project_id) @@ -131,7 +131,7 @@ def test_get_failed_job_with_error( class TestGetMultipleTransformationJobs: def test_get_multiple_jobs_success( - self, client: TestClient, db: Session, user_api_key: AuthContext + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test successfully retrieving multiple transformation jobs.""" store = DocumentStore(db, user_api_key.project_id) @@ -156,7 +156,7 @@ def test_get_multiple_jobs_success( assert returned_ids == expected_ids def test_get_mixed_existing_nonexisting_jobs( - self, client: TestClient, db: Session, user_api_key: AuthContext + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test retrieving a mix of existing and non-existing jobs.""" store = DocumentStore(db, user_api_key.project_id) @@ -181,7 +181,7 @@ def test_get_mixed_existing_nonexisting_jobs( assert data["data"]["jobs_not_found"][0] == fake_uuid def test_get_jobs_with_empty_string( - self, client: TestClient, user_api_key: AuthContext + self, client: TestClient, user_api_key: TestAuthContext ): """Test retrieving jobs with empty job_ids parameter.""" response = client.get( @@ -192,7 +192,7 @@ def test_get_jobs_with_empty_string( assert response.status_code == 422 def test_get_jobs_with_whitespace_only( - self, client: TestClient, user_api_key: AuthContext + self, client: TestClient, user_api_key: TestAuthContext ): """Test retrieving jobs with whitespace-only job_ids.""" response = client.get( @@ -203,7 +203,7 @@ def test_get_jobs_with_whitespace_only( assert response.status_code == 422 def test_get_jobs_invalid_uuid_format_422( - self, client: TestClient, user_api_key: AuthContext + self, client: TestClient, user_api_key: TestAuthContext ): """Test that invalid UUID format returns 422.""" invalid_uuid = "not-a-uuid" @@ -218,7 +218,7 @@ def test_get_jobs_invalid_uuid_format_422( assert "Input should be a valid UUID" in data["error"] def test_get_jobs_mixed_valid_invalid_uuid_422( - self, client: TestClient, db: Session, user_api_key: AuthContext + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test that mixed valid/invalid UUIDs returns 422.""" store = DocumentStore(db, user_api_key.project_id) @@ -239,7 +239,7 @@ def test_get_jobs_mixed_valid_invalid_uuid_422( assert "job_ids" in data["error"] def test_get_jobs_missing_parameter_422( - self, client: TestClient, user_api_key: AuthContext + self, client: TestClient, user_api_key: TestAuthContext ): """Test that missing job_ids parameter returns empty results.""" response = client.get( @@ -253,8 +253,8 @@ def test_get_jobs_different_project_not_found( self, client: TestClient, db: Session, - user_api_key: AuthContext, - superuser_api_key: AuthContext, + user_api_key: TestAuthContext, + superuser_api_key: TestAuthContext, ): """Test that jobs from different projects are not returned.""" store = DocumentStore(db, user_api_key.project_id) @@ -275,7 +275,7 @@ def test_get_jobs_different_project_not_found( assert data["data"]["jobs_not_found"][0] == str(job.id) def test_get_jobs_with_various_statuses( - self, client: TestClient, db: Session, user_api_key: AuthContext + self, client: TestClient, db: Session, user_api_key: TestAuthContext ): """Test retrieving jobs with different statuses.""" store = DocumentStore(db, user_api_key.project_id) diff --git a/backend/app/tests/api/routes/test_openai_conversation.py b/backend/app/tests/api/routes/test_openai_conversation.py index 733ce17e..500a467b 100644 --- a/backend/app/tests/api/routes/test_openai_conversation.py +++ b/backend/app/tests/api/routes/test_openai_conversation.py @@ -4,13 +4,13 @@ from app.crud.openai_conversation import create_conversation from app.models import OpenAIConversationCreate from app.tests.utils.openai import generate_openai_id -from app.tests.utils.auth import AuthContext +from app.tests.utils.auth import TestAuthContext def test_get_conversation_success( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful conversation retrieval.""" @@ -45,7 +45,7 @@ def test_get_conversation_success( def test_get_conversation_not_found( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test conversation retrieval with non-existent ID.""" response = client.get( @@ -61,7 +61,7 @@ def test_get_conversation_not_found( def test_get_conversation_by_response_id_success( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful conversation retrieval by response ID.""" response_id = generate_openai_id("resp_", 40) @@ -96,7 +96,7 @@ def test_get_conversation_by_response_id_success( def test_get_conversation_by_response_id_not_found( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test conversation retrieval with non-existent response ID.""" response = client.get( @@ -112,7 +112,7 @@ def test_get_conversation_by_response_id_not_found( def test_get_conversation_by_ancestor_id_success( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful conversation retrieval by ancestor ID.""" ancestor_response_id = generate_openai_id("resp_", 40) @@ -147,7 +147,7 @@ def test_get_conversation_by_ancestor_id_success( def test_get_conversation_by_ancestor_id_not_found( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test conversation retrieval with non-existent ancestor ID.""" response = client.get( @@ -163,7 +163,7 @@ def test_get_conversation_by_ancestor_id_not_found( def test_list_conversations_success( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful conversation listing.""" conversation_data = OpenAIConversationCreate( @@ -199,7 +199,7 @@ def test_list_conversations_success( def test_list_conversations_with_pagination( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test conversation listing with pagination.""" # Create multiple conversations @@ -262,7 +262,7 @@ def test_list_conversations_with_pagination( def test_list_conversations_pagination_metadata( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test conversation listing pagination metadata.""" # Create 5 conversations @@ -320,7 +320,7 @@ def test_list_conversations_pagination_metadata( def test_list_conversations_default_pagination( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test conversation listing with default pagination parameters.""" # Create a conversation @@ -360,7 +360,7 @@ def test_list_conversations_default_pagination( def test_list_conversations_edge_cases( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test conversation listing edge cases for pagination.""" # Test with skip larger than total @@ -397,7 +397,7 @@ def test_list_conversations_edge_cases( def test_list_conversations_invalid_pagination( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test conversation listing with invalid pagination parameters.""" response = client.get( @@ -411,7 +411,7 @@ def test_list_conversations_invalid_pagination( def test_delete_conversation_success( client: TestClient, db: Session, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test successful conversation deletion.""" conversation_data = OpenAIConversationCreate( @@ -453,7 +453,7 @@ def test_delete_conversation_success( def test_delete_conversation_not_found( client: TestClient, - user_api_key: AuthContext, + user_api_key: TestAuthContext, ): """Test conversation deletion with non-existent ID.""" response = client.delete( diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index c299a909..c9e0597a 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -19,7 +19,7 @@ from app.tests.utils.auth import ( get_superuser_auth_context, get_user_auth_context, - AuthContext, + TestAuthContext, ) from app.seed_data.seed_data import seed_database @@ -84,12 +84,12 @@ def user_api_key_header(db: Session) -> dict[str, str]: @pytest.fixture -def superuser_api_key(db: Session) -> AuthContext: +def superuser_api_key(db: Session) -> TestAuthContext: auth_ctx = get_superuser_auth_context(db) return auth_ctx @pytest.fixture -def user_api_key(db: Session) -> AuthContext: +def user_api_key(db: Session) -> TestAuthContext: auth_ctx = get_user_auth_context(db) return auth_ctx diff --git a/backend/app/tests/core/doctransformer/test_service/conftest.py b/backend/app/tests/core/doctransformer/test_service/conftest.py index ccf7bc2c..6fe7aa76 100644 --- a/backend/app/tests/core/doctransformer/test_service/conftest.py +++ b/backend/app/tests/core/doctransformer/test_service/conftest.py @@ -16,7 +16,7 @@ from app.core.config import settings from app.models import Document, Project, UserProjectOrg from app.tests.utils.document import DocumentStore -from app.tests.utils.auth import get_user_auth_context, AuthContext +from app.tests.utils.auth import get_user_auth_context, TestAuthContext @pytest.fixture(scope="class") @@ -52,7 +52,7 @@ def fast_execute_job_func( @pytest.fixture -def current_user(db: Session, user_api_key: AuthContext) -> UserProjectOrg: +def current_user(db: Session, user_api_key: TestAuthContext) -> UserProjectOrg: """Create a test user for testing.""" api_key = user_api_key user = api_key.user diff --git a/backend/app/tests/utils/auth.py b/backend/app/tests/utils/auth.py index 9a419641..c314c67e 100644 --- a/backend/app/tests/utils/auth.py +++ b/backend/app/tests/utils/auth.py @@ -6,7 +6,7 @@ from app.core.config import settings -class AuthContext(SQLModel): +class TestAuthContext(SQLModel): """Authentication context for testing""" user_id: int @@ -28,7 +28,7 @@ def get_auth_context( project_name: str, raw_key: str, user_type: str = "User", -) -> AuthContext: +) -> TestAuthContext: """ Helper function to get authentication context from seeded data. @@ -40,7 +40,7 @@ def get_auth_context( user_type: Type of user for error messages (e.g., "Superuser", "User") Returns: - AuthContext with all IDs and keys from seeded data + TestAuthContext with all IDs and keys from seeded data Raises: ValueError: If the required data is not found in the database @@ -79,7 +79,7 @@ def get_auth_context( ) # Return complete auth context - return AuthContext( + return TestAuthContext( user_id=user.id, project_id=project.id, organization_id=org.id, @@ -92,7 +92,7 @@ def get_auth_context( ) -def get_superuser_auth_context(session: Session) -> AuthContext: +def get_superuser_auth_context(session: Session) -> TestAuthContext: """ Get authentication context for superuser from seeded data. @@ -105,7 +105,7 @@ def get_superuser_auth_context(session: Session) -> AuthContext: session: Database session Returns: - AuthContext with all IDs and keys from seeded data + TestAuthContext with all IDs and keys from seeded data Raises: ValueError: If the required data is not found in the database @@ -119,7 +119,7 @@ def get_superuser_auth_context(session: Session) -> AuthContext: ) -def get_user_auth_context(session: Session) -> AuthContext: +def get_user_auth_context(session: Session) -> TestAuthContext: """ Get authentication context for normal user from seeded data. @@ -132,7 +132,7 @@ def get_user_auth_context(session: Session) -> AuthContext: session: Database session Returns: - AuthContext with all IDs and keys from seeded data + TestAuthContext with all IDs and keys from seeded data Raises: ValueError: If the required data is not found in the database diff --git a/backend/app/tests/utils/document.py b/backend/app/tests/utils/document.py index f6368f81..dddb1c2a 100644 --- a/backend/app/tests/utils/document.py +++ b/backend/app/tests/utils/document.py @@ -15,7 +15,7 @@ from app.crud.project import get_project_by_id from app.models import Document, DocumentPublic, Project from app.utils import APIResponse -from app.tests.utils.auth import AuthContext +from app.tests.utils.auth import TestAuthContext from .utils import SequentialUuidGenerator @@ -113,7 +113,7 @@ def append(self, doc: Document, suffix: str = None): @dataclass class WebCrawler: client: TestClient - user_api_key: AuthContext + user_api_key: TestAuthContext def get(self, route: Route): return self.client.get( @@ -166,5 +166,5 @@ def to_public_dict(self) -> dict: @pytest.fixture -def crawler(client: TestClient, user_api_key: AuthContext): +def crawler(client: TestClient, user_api_key: TestAuthContext): return WebCrawler(client, user_api_key=user_api_key) diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 79d22351..3acabf63 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -30,7 +30,7 @@ get_document, get_project, ) -from app.tests.utils.auth import AuthContext, get_auth_context +from app.tests.utils.auth import TestAuthContext, get_auth_context def create_test_organization(db: Session) -> Organization: From 8ecf9ea7137fb861ec0eac6f7e6d421b7d7713d3 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 11:59:12 +0530 Subject: [PATCH 14/37] API key creation to include user and project ID parameters, enhancing user-specific key management. --- backend/app/api/routes/api_keys.py | 6 ++++-- backend/app/crud/api_key.py | 19 ++++++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index 970070ec..7bbbfd10 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -18,18 +18,20 @@ ) def create_api_key_route( project_id: int, + user_id: int, current_user: AuthContextDep, session: SessionDep, ): """ - Create a new API key for the current project. + Create a new API key for the project and user, Restricted to Superuser. The raw API key is returned only once during creation. Store it securely as it cannot be retrieved again. """ api_key_crud = APIKeyCrud(session=session, project_id=project_id) raw_key, api_key = api_key_crud.create( - user_id=current_user.id, + user_id=user_id, + project_id=project_id, ) api_key = APIKeyCreateResponse(**api_key.model_dump(), key=raw_key) diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index 99186cd6..9324e299 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -5,7 +5,7 @@ from sqlmodel import Session, select, and_ from fastapi import HTTPException -from app.models import APIKey +from app.models import APIKey, User from app.crud import get_project_by_id from app.core.util import now from app.core.security import api_key_manager @@ -52,23 +52,28 @@ def read_all(self, skip: int = 0, limit: int = 100) -> list[APIKey]: ) return self.session.exec(statement).all() - def create(self, user_id: int) -> Tuple[str, APIKey]: + def create(self, user_id: int, project_id: int) -> Tuple[str, APIKey]: """ Create a new API key for the project. """ try: - raw_key, key_prefix, key_hash = api_key_manager.generate() + project = get_project_by_id(session=self.session, project_id=project_id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") - project = get_project_by_id( - session=self.session, project_id=self.project_id - ) + user = self.session.get(User, user_id) + + if not user: + raise HTTPException(status_code=404, detail="User not found") + + raw_key, key_prefix, key_hash = api_key_manager.generate() api_key = APIKey( key_prefix=key_prefix, key_hash=key_hash, user_id=user_id, organization_id=project.organization_id, - project_id=self.project_id, + project_id=project_id, ) self.session.add(api_key) From 60ba6f579aab3b06be178bbdd6a68f535b997654 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:01:02 +0530 Subject: [PATCH 15/37] enhance test coverage for API key CRUD and Routes operations --- backend/app/crud/api_key.py | 14 +- backend/app/tests/api/routes/test_api_key.py | 114 ++++++++ backend/app/tests/crud/test_api_key.py | 267 +++++++++++++++++++ backend/app/tests/utils/test_data.py | 24 ++ 4 files changed, 412 insertions(+), 7 deletions(-) diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index 9324e299..20121631 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -56,16 +56,16 @@ def create(self, user_id: int, project_id: int) -> Tuple[str, APIKey]: """ Create a new API key for the project. """ - try: - project = get_project_by_id(session=self.session, project_id=project_id) - if not project: - raise HTTPException(status_code=404, detail="Project not found") + project = get_project_by_id(session=self.session, project_id=project_id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") - user = self.session.get(User, user_id) + user = self.session.get(User, user_id) - if not user: - raise HTTPException(status_code=404, detail="User not found") + if not user: + raise HTTPException(status_code=404, detail="User not found") + try: raw_key, key_prefix, key_hash = api_key_manager.generate() api_key = APIKey( diff --git a/backend/app/tests/api/routes/test_api_key.py b/backend/app/tests/api/routes/test_api_key.py index e69de29b..ee3231c0 100644 --- a/backend/app/tests/api/routes/test_api_key.py +++ b/backend/app/tests/api/routes/test_api_key.py @@ -0,0 +1,114 @@ +from uuid import uuid4 + +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.core.config import settings +from app.tests.utils.auth import TestAuthContext +from app.tests.utils.test_data import create_test_api_key, create_test_project +from app.tests.utils.user import create_random_user + + +def test_create_api_key_as_superuser( + db: Session, + client: TestClient, + superuser_token_headers: dict[str, str], +) -> None: + """Test creating an API key as a superuser.""" + + user = create_random_user(db) + project = create_test_project(db) + + response = client.post( + f"{settings.API_V1_STR}/apikeys/", + headers=superuser_token_headers, + params={ + "project_id": project.id, + "user_id": user.id, + }, + ) + assert response.status_code == 201 + data = response.json() + assert data["success"] is True + assert "data" in data + assert "key" in data["data"] + assert "id" in data["data"] + assert "key_prefix" in data["data"] + assert data["data"]["project_id"] == project.id + assert data["data"]["user_id"] == user.id + assert data["data"]["organization_id"] == project.organization_id + assert data["data"]["key"].startswith("ApiKey ") + + +def test_create_api_key_as_normal_user_forbidden( + client: TestClient, + normal_user_token_headers: dict[str, str], + user_api_key: TestAuthContext, +) -> None: + """Test that normal users cannot create API keys (superuser only).""" + response = client.post( + f"{settings.API_V1_STR}/apikeys/", + headers=normal_user_token_headers, + params={ + "project_id": user_api_key.project_id, + "user_id": user_api_key.user_id, + }, + ) + assert response.status_code == 403 + + +def test_list_api_keys( + db: Session, + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test listing API keys as a normal user.""" + created_keys = [] + for _ in range(3): + key = create_test_api_key( + db=db, + project_id=user_api_key.project_id, + user_id=user_api_key.user_id, + ) + created_keys.append(key) + + response = client.get( + f"{settings.API_V1_STR}/apikeys/", + headers={"X-API-KEY": user_api_key.key}, + ) + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert isinstance(data["data"], list) + # Verify we have at least the 3 created keys + the fixture key (4 total) + assert len(data["data"]) >= 4 + + +def test_delete_api_key( + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test deleting an API key by its owner.""" + + delete_response = client.delete( + f"{settings.API_V1_STR}/apikeys/{user_api_key.api_key_id}", + headers={"X-API-KEY": user_api_key.key}, + ) + assert delete_response.status_code == 200 + data = delete_response.json() + assert data["success"] is True + assert "message" in data["data"] + assert "deleted successfully" in data["data"]["message"].lower() + + +def test_delete_api_key_nonexistent( + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test deleting a non-existent API key.""" + fake_uuid = uuid4() + response = client.delete( + f"{settings.API_V1_STR}/apikeys/{fake_uuid}", + headers={"X-API-KEY": user_api_key.key}, + ) + assert response.status_code == 404 diff --git a/backend/app/tests/crud/test_api_key.py b/backend/app/tests/crud/test_api_key.py index e69de29b..939837c4 100644 --- a/backend/app/tests/crud/test_api_key.py +++ b/backend/app/tests/crud/test_api_key.py @@ -0,0 +1,267 @@ +import pytest +from uuid import uuid4 +from sqlmodel import Session +from fastapi import HTTPException + +from app.crud import APIKeyCrud +from app.models import APIKey, Project, User +from app.tests.utils.test_data import create_test_project, create_test_api_key +from app.tests.utils.user import create_random_user +from app.tests.utils.utils import get_non_existent_id + + +def test_create_api_key(db: Session) -> None: + """Test creating a new API key""" + project = create_test_project(db) + user = create_random_user(db) + + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + raw_key, api_key = api_key_crud.create(user_id=user.id, project_id=project.id) + + assert api_key.id is not None + assert api_key.user_id == user.id + assert api_key.project_id == project.id + assert api_key.organization_id == project.organization_id + assert api_key.key_prefix is not None + assert api_key.key_hash is not None + assert api_key.is_deleted is False + assert api_key.deleted_at is None + assert raw_key is not None + assert len(raw_key) > 0 + + +def test_create_api_key_with_nonexistent_project(db: Session) -> None: + """Test creating API key with a project that doesn't exist""" + user = create_random_user(db) + fake_project_id = get_non_existent_id(session=db, model=Project) + + api_key_crud = APIKeyCrud(session=db, project_id=fake_project_id) + + with pytest.raises(HTTPException) as exc_info: + api_key_crud.create(user_id=user.id, project_id=fake_project_id) + + assert exc_info.value.status_code == 404 + assert "Project not found" in str(exc_info.value.detail) + + +def test_create_api_key_with_nonexistent_user(db: Session) -> None: + """Test creating API key with a user that doesn't exist""" + project = create_test_project(db) + fake_user_id = get_non_existent_id(session=db, model=User) + + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + + with pytest.raises(HTTPException) as exc_info: + api_key_crud.create(user_id=fake_user_id, project_id=project.id) + + assert exc_info.value.status_code == 404 + assert "User not found" in str(exc_info.value.detail) + + +def test_read_one_api_key(db: Session) -> None: + """Test reading a single API key by ID""" + api_key = create_test_api_key(db=db) + api_key_crud = APIKeyCrud(session=db, project_id=api_key.project_id) + retrieved_key = api_key_crud.read_one(key_id=api_key.id) + + assert retrieved_key is not None + assert retrieved_key.id == api_key.id + assert retrieved_key.key_prefix == api_key.key_prefix + assert retrieved_key.project_id == api_key.project_id + + +def test_read_one_api_key_nonexistent(db: Session) -> None: + """Test reading an API key that doesn't exist""" + project = create_test_project(db) + + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + fake_key_id = uuid4() + + retrieved_key = api_key_crud.read_one(key_id=fake_key_id) + + assert retrieved_key is None + + +def test_read_one_api_key_wrong_project(db: Session) -> None: + """Test that reading an API key from a different project returns None""" + + api_key = create_test_api_key(db=db) + project2 = create_test_project(db) + + # Try to read it from project2 scope + api_key_crud2 = APIKeyCrud(session=db, project_id=project2.id) + retrieved_key = api_key_crud2.read_one(key_id=api_key.id) + + assert retrieved_key is None + + +def test_read_one_deleted_api_key(db: Session) -> None: + """Test that reading a deleted API key returns None""" + api_key = create_test_api_key(db=db) + + api_key_crud = APIKeyCrud(session=db, project_id=api_key.project_id) + + api_key_crud.delete(key_id=api_key.id) + retrieved_key = api_key_crud.read_one(key_id=api_key.id) + assert retrieved_key is None + + +def test_read_all_api_keys(db: Session) -> None: + """Test reading all API keys for a project""" + project = create_test_project(db) + user = create_random_user(db) + + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + + # Create multiple API keys + key1 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + key2 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + key3 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + + # Read all keys + all_keys = api_key_crud.read_all() + + assert len(all_keys) == 3 + key_ids = {key.id for key in all_keys} + assert key1.id in key_ids + assert key2.id in key_ids + assert key3.id in key_ids + + +def test_read_all_api_keys_with_pagination(db: Session) -> None: + """Test reading API keys with skip and limit parameters""" + project = create_test_project(db) + user = create_random_user(db) + + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + + # Create 5 API keys + for _ in range(5): + create_test_api_key(db=db, project_id=project.id, user_id=user.id) + + # Test pagination + page1 = api_key_crud.read_all(skip=0, limit=2) + assert len(page1) == 2 + + page2 = api_key_crud.read_all(skip=2, limit=2) + assert len(page2) == 2 + + page3 = api_key_crud.read_all(skip=4, limit=2) + assert len(page3) == 1 + + all_ids = ( + {key.id for key in page1} + | {key.id for key in page2} + | {key.id for key in page3} + ) + assert len(all_ids) == 5 + + +def test_read_all_excludes_deleted_keys(db: Session) -> None: + """Test that read_all excludes deleted API keys""" + project = create_test_project(db) + user = create_random_user(db) + + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + + # Create 3 API keys + key1 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + key2 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + key3 = create_test_api_key(db=db, project_id=project.id, user_id=user.id) + + # Delete one + api_key_crud.delete(key_id=key2.id) + + # Read all should only return 2 + all_keys = api_key_crud.read_all() + assert len(all_keys) == 2 + key_ids = {key.id for key in all_keys} + assert key1.id in key_ids + assert key2.id not in key_ids + assert key3.id in key_ids + + +def test_read_all_scoped_to_project(db: Session) -> None: + """Test that read_all only returns keys for the specified project""" + project1 = create_test_project(db) + project2 = create_test_project(db) + user = create_random_user(db) + + api_key_crud1 = APIKeyCrud(session=db, project_id=project1.id) + api_key_crud2 = APIKeyCrud(session=db, project_id=project2.id) + + # Create keys for both projects + api_key_crud1.create(user_id=user.id, project_id=project1.id) + api_key_crud1.create(user_id=user.id, project_id=project1.id) + api_key_crud2.create(user_id=user.id, project_id=project2.id) + + # Each project should only see its own keys + project1_keys = api_key_crud1.read_all() + project2_keys = api_key_crud2.read_all() + + assert len(project1_keys) == 2 + assert len(project2_keys) == 1 + assert all(key.project_id == project1.id for key in project1_keys) + assert all(key.project_id == project2.id for key in project2_keys) + + +def test_delete_api_key(db: Session) -> None: + """Test soft deleting an API key""" + api_key = create_test_api_key(db=db) + + api_key_crud = APIKeyCrud(session=db, project_id=api_key.project_id) + + api_key_crud.delete(key_id=api_key.id) + + db_key = db.get(APIKey, api_key.id) + assert db_key is not None + assert db_key.is_deleted is True + assert db_key.deleted_at is not None + + retrieved_key = api_key_crud.read_one(key_id=api_key.id) + assert retrieved_key is None + + +def test_delete_nonexistent_api_key(db: Session) -> None: + """Test deleting an API key that doesn't exist""" + project = create_test_project(db) + + api_key_crud = APIKeyCrud(session=db, project_id=project.id) + fake_key_id = uuid4() + + with pytest.raises(HTTPException) as exc_info: + api_key_crud.delete(key_id=fake_key_id) + + assert exc_info.value.status_code == 404 + assert "API Key not found" in str(exc_info.value.detail) + + +def test_delete_api_key_from_wrong_project(db: Session) -> None: + """Test that deleting an API key from a different project fails""" + api_key = create_test_api_key(db=db) + project2 = create_test_project(db) + + api_key_crud2 = APIKeyCrud(session=db, project_id=project2.id) + with pytest.raises(HTTPException) as exc_info: + api_key_crud2.delete(key_id=api_key.id) + + assert exc_info.value.status_code == 404 + assert "API Key not found" in str(exc_info.value.detail) + + db_key = db.get(APIKey, api_key.id) + assert db_key is not None + assert db_key.is_deleted is False + + +def test_delete_already_deleted_api_key(db: Session) -> None: + """Test deleting an API key that's already deleted""" + api_key = create_test_api_key(db=db) + api_key_crud = APIKeyCrud(session=db, project_id=api_key.project_id) + + api_key_crud.delete(key_id=api_key.id) + + with pytest.raises(HTTPException) as exc_info: + api_key_crud.delete(key_id=api_key.id) + + assert exc_info.value.status_code == 404 + assert "API Key not found" in str(exc_info.value.detail) diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 3acabf63..079204e3 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -4,6 +4,7 @@ Organization, Project, APIKey, + APIKeyCreateResponse, Credential, OrganizationCreate, ProjectCreate, @@ -84,6 +85,29 @@ def test_credential_data(db: Session) -> CredsCreate: return creds_data +def create_test_api_key( + db: Session, + project_id: int | None = None, + user_id: int | None = None, +) -> APIKeyCreateResponse: + """ + Creates and returns a test API key for a specific project and user. + + Persists the API key to the database. + """ + if user_id is None: + user = create_random_user(db) + user_id = user.id + + if project_id is None: + project = create_test_project(db) + project_id = project.id + + api_key_crud = APIKeyCrud(session=db, project_id=project_id) + raw_key, api_key = api_key_crud.create(user_id=user_id, project_id=project_id) + return api_key + + def create_test_credential(db: Session) -> tuple[list[Credential], Project]: """ Creates and returns a test credential for a test project. From 1cb5091db16daf896681b3fa6a953a05affa2ad6 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:14:55 +0530 Subject: [PATCH 16/37] Add tests for API key manager --- backend/app/tests/core/test_security.py | 187 ++++++++++++++++++++++++ backend/app/tests/utils/test_data.py | 2 +- 2 files changed, 188 insertions(+), 1 deletion(-) diff --git a/backend/app/tests/core/test_security.py b/backend/app/tests/core/test_security.py index 4f7f6861..88e66bea 100644 --- a/backend/app/tests/core/test_security.py +++ b/backend/app/tests/core/test_security.py @@ -1,11 +1,16 @@ import pytest +from uuid import uuid4 +from sqlmodel import Session from app.core.security import ( get_password_hash, verify_password, encrypt_api_key, decrypt_api_key, get_encryption_key, + APIKeyManager, ) +from app.models import APIKey, User, Organization, Project, AuthContext +from app.tests.utils.test_data import create_test_api_key def test_encrypt_decrypt_api_key(): @@ -119,3 +124,185 @@ def test_get_encryption_key(): assert isinstance(key, bytes) # The key is base64 encoded, so it should be 44 bytes assert len(key) == 44 # Base64 encoded Fernet key length is 44 bytes + + +class TestAPIKeyManager: + """Test suite for APIKeyManager class.""" + + def test_generate_returns_correct_tuple(self): + """Test that generate returns a tuple of (raw_key, key_prefix, key_hash).""" + raw_key, key_prefix, key_hash = APIKeyManager.generate() + + assert isinstance(raw_key, str) + assert isinstance(key_prefix, str) + assert isinstance(key_hash, str) + + def test_generate_raw_key_format(self): + """Test that generated raw key has correct format.""" + raw_key, key_prefix, key_hash = APIKeyManager.generate() + + # Should start with "ApiKey " + assert raw_key.startswith(APIKeyManager.PREFIX_NAME) + + # Should have correct length (7 for "ApiKey " + 22 for prefix + 43 for secret) + expected_length = len(APIKeyManager.PREFIX_NAME) + APIKeyManager.KEY_LENGTH + assert len(raw_key) == expected_length + + def test_generate_key_prefix_length(self): + """Test that generated key prefix has correct length.""" + raw_key, key_prefix, key_hash = APIKeyManager.generate() + + assert len(key_prefix) == APIKeyManager.PREFIX_LENGTH + + def test_generate_unique_keys(self): + """Test that generate creates unique keys on each call.""" + raw_key1, prefix1, hash1 = APIKeyManager.generate() + raw_key2, prefix2, hash2 = APIKeyManager.generate() + + assert raw_key1 != raw_key2 + assert prefix1 != prefix2 + assert hash1 != hash2 + + def test_generate_hash_is_bcrypt(self): + """Test that the generated hash uses bcrypt format.""" + raw_key, key_prefix, key_hash = APIKeyManager.generate() + + # bcrypt hashes start with $2b$ (or $2a$ or $2y$) + assert key_hash.startswith("$2") + + def test_extract_key_parts_new_format(self): + """Test extracting key parts from new format (65 chars).""" + # Generate a test key with known format + raw_key, expected_prefix, _ = APIKeyManager.generate() + + # Extract parts + result = APIKeyManager._extract_key_parts(raw_key) + + assert result is not None + extracted_prefix, secret = result + assert extracted_prefix == expected_prefix + assert len(secret) == APIKeyManager.KEY_LENGTH - APIKeyManager.PREFIX_LENGTH + + def test_extract_key_parts_old_format(self): + """Test extracting key parts from old format (43 chars).""" + old_prefix = "a" * 12 + old_secret = "b" * 31 + raw_key = f"{APIKeyManager.PREFIX_NAME}{old_prefix}{old_secret}" + + result = APIKeyManager._extract_key_parts(raw_key) + + assert result is not None + extracted_prefix, secret = result + assert extracted_prefix == old_prefix + assert secret == old_secret + + def test_extract_key_parts_invalid_prefix(self): + """Test that invalid prefix returns None.""" + invalid_key = "InvalidPrefix abcdefghij1234567890" + + result = APIKeyManager._extract_key_parts(invalid_key) + + assert result is None + + def test_extract_key_parts_invalid_length(self): + """Test that invalid length returns None.""" + # Too short + invalid_key = f"{APIKeyManager.PREFIX_NAME}tooshort" + + result = APIKeyManager._extract_key_parts(invalid_key) + + assert result is None + + def test_verify_valid_key(self, db: Session): + """Test verifying a valid API key.""" + api_key = create_test_api_key(db) + + # Verify the key + auth_context = APIKeyManager.verify(db, api_key.key) + + user = db.get(User, api_key.user_id) + organization = db.get(Organization, api_key.organization_id) + project = db.get(Project, api_key.project_id) + + assert auth_context is not None + assert isinstance(auth_context, AuthContext) + assert auth_context.user_id == api_key.user_id + assert auth_context.organization_id == api_key.organization_id + assert auth_context.project_id == api_key.project_id + assert auth_context.user == user + assert auth_context.organization == organization + assert auth_context.project == project + + def test_verify_invalid_key(self, db: Session): + """Test verifying an invalid API key.""" + # Generate a key but don't store it + raw_key, _, _ = APIKeyManager.generate() + + # Verify should return None + auth_context = APIKeyManager.verify(db, raw_key) + + assert auth_context is None + + def test_verify_wrong_secret(self, db: Session): + """Test verifying with correct prefix but wrong secret.""" + create_test_api_key(db) + + # Generate a different key to try verification + raw_key2, _, _ = APIKeyManager.generate() + + # Try to verify with key2 (wrong secret) + auth_context = APIKeyManager.verify(db, raw_key2) + + assert auth_context is None + + def test_verify_deleted_key(self, db: Session): + """Test that deleted API keys cannot be verified.""" + # Create test API key + api_key_response = create_test_api_key(db) + raw_key = api_key_response.key + + # Mark the API key as deleted + api_key = db.get(APIKey, api_key_response.id) + api_key.is_deleted = True + db.commit() + + # Verify should return None for deleted key + auth_context = APIKeyManager.verify(db, raw_key) + + assert auth_context is None + + def test_verify_malformed_key(self, db: Session): + """Test verifying with malformed key format.""" + malformed_keys = [ + "not_an_api_key", + "", + "ApiKey", + "ApiKey ", + None, + ] + + for malformed_key in malformed_keys: + if malformed_key is not None: + auth_context = APIKeyManager.verify(db, malformed_key) + assert auth_context is None + + def test_prefix_name_constant(self): + """Test that PREFIX_NAME is correct.""" + assert APIKeyManager.PREFIX_NAME == "ApiKey " + + def test_key_length_constants(self): + """Test that key length constants are correct.""" + assert APIKeyManager.PREFIX_LENGTH == 22 + assert APIKeyManager.KEY_LENGTH == 65 + assert APIKeyManager.KEY_LENGTH == APIKeyManager.PREFIX_LENGTH + 43 + + def test_generate_creates_verifiable_key(self, db: Session): + """Integration test: generated key can be verified.""" + # Create test API key + api_key_response = create_test_api_key(db) + + # Verify the key works + auth_context = APIKeyManager.verify(db, api_key_response.key) + + assert auth_context is not None + assert auth_context.user_id == api_key_response.user_id diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 079204e3..76429eb8 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -105,7 +105,7 @@ def create_test_api_key( api_key_crud = APIKeyCrud(session=db, project_id=project_id) raw_key, api_key = api_key_crud.create(user_id=user_id, project_id=project_id) - return api_key + return APIKeyCreateResponse(key=raw_key, **api_key.dict()) def create_test_credential(db: Session) -> tuple[list[Credential], Project]: From 520af6d6f141fd69e3f1bd3c88c75c74df6a1282 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:20:22 +0530 Subject: [PATCH 17/37] authentication context functions for consistency in test files --- backend/app/api/deps.py | 21 +++++++++---------- backend/app/tests/conftest.py | 12 +++++------ .../doctransformer/test_service/conftest.py | 2 +- backend/app/tests/utils/auth.py | 10 ++++----- backend/app/tests/utils/test_data.py | 1 - 5 files changed, 22 insertions(+), 24 deletions(-) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 70ee5715..24c873cd 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,10 +1,9 @@ from collections.abc import Generator -from typing import Annotated, Optional +from typing import Annotated import jwt -from fastapi import Depends, HTTPException, status, Request, Header, Security -from fastapi.responses import JSONResponse -from fastapi.security import OAuth2PasswordBearer, APIKeyHeader +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from jwt.exceptions import InvalidTokenError from pydantic import ValidationError from sqlmodel import Session, select @@ -13,19 +12,19 @@ from app.core.config import settings from app.core.db import engine from app.core.security import api_key_manager -from app.utils import APIResponse from app.crud.organization import validate_organization from app.models import ( AuthContext, + Organization, + Project, + ProjectUser, TokenPayload, User, - UserProjectOrg, UserOrganization, - ProjectUser, - Project, - Organization, + UserProjectOrg, ) + reusable_oauth2 = OAuth2PasswordBearer( tokenUrl=f"{settings.API_V1_STR}/login/access-token", auto_error=False ) @@ -151,7 +150,7 @@ def get_current_active_superuser_org(current_user: CurrentUserOrg) -> User: return current_user -def get_user_context( +def get_auth_context( session: SessionDep, token: TokenDep, api_key: Annotated[str, Depends(api_key_header)], @@ -205,7 +204,7 @@ def get_user_context( raise HTTPException(status_code=401, detail="Invalid Authorization format") -AuthContextDep = Annotated[AuthContext, Depends(get_user_context)] +AuthContextDep = Annotated[AuthContext, Depends(get_auth_context)] def verify_user_project_organization( diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index c9e0597a..73277dcc 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -17,8 +17,8 @@ from app.tests.utils.user import authentication_token_from_email from app.tests.utils.utils import get_superuser_token_headers from app.tests.utils.auth import ( - get_superuser_auth_context, - get_user_auth_context, + get_superuser_test_auth_context, + get_user_test_auth_context, TestAuthContext, ) from app.seed_data.seed_data import seed_database @@ -73,23 +73,23 @@ def normal_user_token_headers(client: TestClient, db: Session) -> dict[str, str] @pytest.fixture def superuser_api_key_header(db: Session) -> dict[str, str]: - auth_ctx = get_superuser_auth_context(db) + auth_ctx = get_superuser_test_auth_context(db) return {"X-API-KEY": auth_ctx.key} @pytest.fixture def user_api_key_header(db: Session) -> dict[str, str]: - auth_ctx = get_user_auth_context(db) + auth_ctx = get_user_test_auth_context(db) return {"X-API-KEY": auth_ctx.key} @pytest.fixture def superuser_api_key(db: Session) -> TestAuthContext: - auth_ctx = get_superuser_auth_context(db) + auth_ctx = get_superuser_test_auth_context(db) return auth_ctx @pytest.fixture def user_api_key(db: Session) -> TestAuthContext: - auth_ctx = get_user_auth_context(db) + auth_ctx = get_user_test_auth_context(db) return auth_ctx diff --git a/backend/app/tests/core/doctransformer/test_service/conftest.py b/backend/app/tests/core/doctransformer/test_service/conftest.py index 6fe7aa76..75574be0 100644 --- a/backend/app/tests/core/doctransformer/test_service/conftest.py +++ b/backend/app/tests/core/doctransformer/test_service/conftest.py @@ -16,7 +16,7 @@ from app.core.config import settings from app.models import Document, Project, UserProjectOrg from app.tests.utils.document import DocumentStore -from app.tests.utils.auth import get_user_auth_context, TestAuthContext +from app.tests.utils.auth import TestAuthContext @pytest.fixture(scope="class") diff --git a/backend/app/tests/utils/auth.py b/backend/app/tests/utils/auth.py index c314c67e..1f35becb 100644 --- a/backend/app/tests/utils/auth.py +++ b/backend/app/tests/utils/auth.py @@ -22,7 +22,7 @@ class TestAuthContext(SQLModel): api_key: APIKey -def get_auth_context( +def get_test_auth_context( session: Session, user_email: str, project_name: str, @@ -92,7 +92,7 @@ def get_auth_context( ) -def get_superuser_auth_context(session: Session) -> TestAuthContext: +def get_superuser_test_auth_context(session: Session) -> TestAuthContext: """ Get authentication context for superuser from seeded data. @@ -110,7 +110,7 @@ def get_superuser_auth_context(session: Session) -> TestAuthContext: Raises: ValueError: If the required data is not found in the database """ - return get_auth_context( + return get_test_auth_context( session=session, user_email=settings.FIRST_SUPERUSER, project_name="Glific", @@ -119,7 +119,7 @@ def get_superuser_auth_context(session: Session) -> TestAuthContext: ) -def get_user_auth_context(session: Session) -> TestAuthContext: +def get_user_test_auth_context(session: Session) -> TestAuthContext: """ Get authentication context for normal user from seeded data. @@ -137,7 +137,7 @@ def get_user_auth_context(session: Session) -> TestAuthContext: Raises: ValueError: If the required data is not found in the database """ - return get_auth_context( + return get_test_auth_context( session=session, user_email=settings.EMAIL_TEST_USER, project_name="Dalgo", diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 76429eb8..a5ba47f4 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -31,7 +31,6 @@ get_document, get_project, ) -from app.tests.utils.auth import TestAuthContext, get_auth_context def create_test_organization(db: Session) -> Organization: From d94adecb55dd6e54ef45373dff6db7c105c74bd5 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:27:18 +0530 Subject: [PATCH 18/37] replace APIKeyCrud with create_test_api_key for consistency in test setup --- backend/app/tests/utils/collection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index 111e8243..91382b81 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -8,7 +8,7 @@ from app.models import Collection, Organization, Project from app.tests.utils.utils import get_user_id_by_email from app.tests.utils.test_data import create_test_project -from app.crud import APIKeyCrud +from app.tests.utils.test_data import create_test_api_key class constants: @@ -29,7 +29,7 @@ def get_collection(db: Session, client=None, owner_id: int = None) -> Collection project = create_test_project(db) # Step 2: Create API key for user with valid foreign keys - APIKeyCrud(session=db, project_id=project.id).create(user_id=owner_id) + create_test_api_key(db, user_id=owner_id, project_id=project.id) if client is None: client = OpenAI(api_key="test_api_key") From 557c143991dc857f6cf48902e2efe7bbbd6f430b Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:44:51 +0530 Subject: [PATCH 19/37] Add tests for get_user_context --- backend/app/api/deps.py | 8 +- backend/app/tests/api/test_deps.py | 162 ++++++++++++++++++++++++++++- 2 files changed, 165 insertions(+), 5 deletions(-) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 24c873cd..1fc5d619 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -166,13 +166,13 @@ def get_auth_context( raise HTTPException(status_code=401, detail="Invalid API Key") if not auth_context.user.is_active: - raise HTTPException(status_code=400, detail="Inactive user") + raise HTTPException(status_code=403, detail="Inactive user") if not auth_context.organization.is_active: - raise HTTPException(status_code=400, detail="Inactive Organization") + raise HTTPException(status_code=403, detail="Inactive Organization") if not auth_context.project.is_active: - raise HTTPException(status_code=400, detail="Inactive Project") + raise HTTPException(status_code=403, detail="Inactive Project") return auth_context @@ -192,7 +192,7 @@ def get_auth_context( if not user: raise HTTPException(status_code=404, detail="User not found") if not user.is_active: - raise HTTPException(status_code=400, detail="Inactive user") + raise HTTPException(status_code=403, detail="Inactive user") auth_context = AuthContext( user_id=user.id, diff --git a/backend/app/tests/api/test_deps.py b/backend/app/tests/api/test_deps.py index 64280ebb..64dd4643 100644 --- a/backend/app/tests/api/test_deps.py +++ b/backend/app/tests/api/test_deps.py @@ -2,7 +2,7 @@ import uuid from sqlmodel import Session, select from fastapi import HTTPException -from app.api.deps import verify_user_project_organization +from app.api.deps import verify_user_project_organization, get_auth_context from app.models import ( User, Organization, @@ -10,9 +10,14 @@ ProjectUser, UserProjectOrg, UserOrganization, + AuthContext, ) from app.tests.utils.utils import random_email from app.core.security import get_password_hash +from app.tests.utils.auth import TestAuthContext +from app.tests.utils.user import authentication_token_from_email, create_random_user +from app.core.config import settings +from app.tests.utils.test_data import create_test_api_key, create_test_project def create_org_project( @@ -159,3 +164,158 @@ def test_verify_inactive_project(db: Session): assert exc_info.value.status_code == 400 assert exc_info.value.detail == "Project is not active" + + +class TestGetAuthContext: + """Test suite for get_auth_context function""" + + def test_get_auth_context_with_valid_api_key( + self, db: Session, user_api_key: TestAuthContext + ): + """Test successful authentication with valid API key""" + auth_context = get_auth_context( + session=db, + token=None, + api_key=user_api_key.key, + ) + + assert isinstance(auth_context, AuthContext) + assert auth_context.user == user_api_key.user + assert auth_context.project == user_api_key.project + assert auth_context.organization == user_api_key.organization + + def test_get_auth_context_with_invalid_api_key(self, db: Session): + """Test authentication fails with invalid API key""" + invalid_api_key = "ApiKey InvalidKeyThatDoesNotExist123456789" + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=None, + api_key=invalid_api_key, + ) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid API Key" + + def test_get_auth_context_with_valid_token( + self, db: Session, normal_user_token_headers: dict[str, str] + ): + """Test successful authentication with valid token""" + token = normal_user_token_headers["Authorization"].replace("Bearer ", "") + auth_context = get_auth_context( + session=db, + token=token, + api_key=None, + ) + + # Assert + assert isinstance(auth_context, AuthContext) + assert auth_context.user.email == settings.EMAIL_TEST_USER + + def test_get_auth_context_with_invalid_token(self, db: Session): + """Test authentication fails with invalid token""" + invalid_token = "invalid.token" + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=invalid_token, + api_key=None, + ) + + assert exc_info.value.status_code == 403 + + def test_get_auth_context_with_no_credentials(self, db: Session): + """Test authentication fails when neither API key nor token is provided""" + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=None, + api_key=None, + ) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid Authorization format" + + def test_get_auth_context_with_inactive_user_via_api_key(self, db: Session): + """Test authentication fails when API key belongs to inactive user""" + api_key = create_test_api_key(db) + + user = db.get(User, api_key.user_id) + user.is_active = False + db.add(user) + db.commit() + db.refresh(user) + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=None, + api_key=api_key.key, + ) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Inactive user" + + def test_get_auth_context_with_inactive_user_via_token(self, db: Session, client): + """Test authentication fails when token belongs to inactive user""" + user = create_random_user(db) + token_headers = authentication_token_from_email( + client=client, email=user.email, db=db + ) + token = token_headers["Authorization"].replace("Bearer ", "") + + user.is_active = False + db.add(user) + db.commit() + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=token, + api_key=None, + ) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Inactive user" + + def test_get_auth_context_with_inactive_organization( + self, db: Session, user_api_key: TestAuthContext + ): + """Test authentication fails when organization is inactive""" + organization = user_api_key.organization + organization.is_active = False + db.add(organization) + db.commit() + db.refresh(organization) + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=None, + api_key=user_api_key.key, + ) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Inactive Organization" + + def test_get_auth_context_with_inactive_project( + self, db: Session, user_api_key: TestAuthContext + ): + """Test authentication fails when project is inactive""" + project = user_api_key.project + project.is_active = False + db.add(project) + db.commit() + db.refresh(project) + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + session=db, + token=None, + api_key=user_api_key.key, + ) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Inactive Project" From b94c1d951cc3ffa5ed00a2e3f4ca4f1e3680d3f0 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 14:07:27 +0530 Subject: [PATCH 20/37] Add tests for permission checks and permission enum functionality --- backend/app/tests/api/test_permissions.py | 134 ++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 backend/app/tests/api/test_permissions.py diff --git a/backend/app/tests/api/test_permissions.py b/backend/app/tests/api/test_permissions.py new file mode 100644 index 00000000..70cfa3b7 --- /dev/null +++ b/backend/app/tests/api/test_permissions.py @@ -0,0 +1,134 @@ +import pytest +from fastapi import HTTPException +from sqlmodel import Session + +from app.models import User +from app.api.permissions import Permission, has_permission, require_permission +from app.api.deps import get_auth_context +from app.tests.utils.test_data import create_test_api_key + + +class TestHasPermission: + """Test suite for has_permission function""" + + def test_superuser_permission_with_superuser(self, db: Session): + """Test that superuser has SUPERUSER permission""" + api_key_response = create_test_api_key(db) + user = db.get(User, api_key_response.user_id) + user.is_superuser = True + db.add(user) + db.commit() + db.refresh(user) + + auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + + result = has_permission(auth_context, Permission.SUPERUSER, db) + + assert result is True + + def test_superuser_permission_with_regular_user(self, db: Session): + """Test that regular user does not have SUPERUSER permission""" + api_key_response = create_test_api_key(db) + + auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + + result = has_permission(auth_context, Permission.SUPERUSER, db) + + assert result is False + + def test_require_organization_permission_with_organization(self, db: Session): + """Test that user with organization has REQUIRE_ORGANIZATION permission""" + api_key_response = create_test_api_key(db) + + auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + + result = has_permission(auth_context, Permission.REQUIRE_ORGANIZATION, db) + + assert result is True + + def test_require_organization_permission_without_organization(self, db: Session): + """Test that user without organization does not have REQUIRE_ORGANIZATION permission""" + api_key_response = create_test_api_key(db) + + auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + + auth_context.organization_id = None + auth_context.organization = None + + result = has_permission(auth_context, Permission.REQUIRE_ORGANIZATION, db) + + assert result is False + + def test_require_project_permission_with_project(self, db: Session): + """Test that user with project has REQUIRE_PROJECT permission""" + api_key_response = create_test_api_key(db) + + auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + + result = has_permission(auth_context, Permission.REQUIRE_PROJECT, db) + + assert result is True + + def test_require_project_permission_without_project(self, db: Session): + """Test that user without project does not have REQUIRE_PROJECT permission""" + api_key_response = create_test_api_key(db) + + auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + + auth_context.project_id = None + auth_context.project = None + + result = has_permission(auth_context, Permission.REQUIRE_PROJECT, db) + + assert result is False + + +class TestRequirePermission: + """Test suite for require_permission dependency factory""" + + def test_returns_valid_permission_checker(self): + """Test that require_permission returns a valid callable permission checker""" + permission_checker = require_permission(Permission.SUPERUSER) + + assert callable(permission_checker) + + def test_permission_checker_passes_with_valid_permission(self, db: Session): + """Test that permission checker passes when user has required permission""" + api_key_response = create_test_api_key(db) + user = db.get(User, api_key_response.user_id) + user.is_superuser = True + db.add(user) + db.commit() + db.refresh(user) + auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + + permission_checker = require_permission(Permission.SUPERUSER) + permission_checker(auth_context, db) + + def test_permission_checker_raises_403_without_permission(self, db: Session): + """Test that permission checker raises HTTPException with 403 when user lacks permission""" + api_key_response = create_test_api_key(db) + auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + + permission_checker = require_permission(Permission.SUPERUSER) + + with pytest.raises(HTTPException) as exc_info: + permission_checker(auth_context, db) + + assert exc_info.value.status_code == 403 + + +class TestPermissionEnum: + """Test suite for Permission enum""" + + def test_permission_enum_values(self): + """Test that Permission enum has expected values""" + assert Permission.SUPERUSER.value == "require_superuser" + assert Permission.REQUIRE_ORGANIZATION.value == "require_organization_id" + assert Permission.REQUIRE_PROJECT.value == "require_project_id" + + def test_permission_enum_is_string(self): + """Test that Permission enum members are strings""" + assert isinstance(Permission.SUPERUSER, str) + assert isinstance(Permission.REQUIRE_ORGANIZATION, str) + assert isinstance(Permission.REQUIRE_PROJECT, str) From 1c128e99a07be9ff251bb4899662cba6f189c84c Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Tue, 7 Oct 2025 14:07:55 +0530 Subject: [PATCH 21/37] pre commt --- backend/app/tests/api/test_permissions.py | 32 +++++++++++++++++------ 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/backend/app/tests/api/test_permissions.py b/backend/app/tests/api/test_permissions.py index 70cfa3b7..1d427d42 100644 --- a/backend/app/tests/api/test_permissions.py +++ b/backend/app/tests/api/test_permissions.py @@ -20,7 +20,9 @@ def test_superuser_permission_with_superuser(self, db: Session): db.commit() db.refresh(user) - auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) result = has_permission(auth_context, Permission.SUPERUSER, db) @@ -30,7 +32,9 @@ def test_superuser_permission_with_regular_user(self, db: Session): """Test that regular user does not have SUPERUSER permission""" api_key_response = create_test_api_key(db) - auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) result = has_permission(auth_context, Permission.SUPERUSER, db) @@ -40,7 +44,9 @@ def test_require_organization_permission_with_organization(self, db: Session): """Test that user with organization has REQUIRE_ORGANIZATION permission""" api_key_response = create_test_api_key(db) - auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) result = has_permission(auth_context, Permission.REQUIRE_ORGANIZATION, db) @@ -50,7 +56,9 @@ def test_require_organization_permission_without_organization(self, db: Session) """Test that user without organization does not have REQUIRE_ORGANIZATION permission""" api_key_response = create_test_api_key(db) - auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) auth_context.organization_id = None auth_context.organization = None @@ -63,7 +71,9 @@ def test_require_project_permission_with_project(self, db: Session): """Test that user with project has REQUIRE_PROJECT permission""" api_key_response = create_test_api_key(db) - auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) result = has_permission(auth_context, Permission.REQUIRE_PROJECT, db) @@ -73,7 +83,9 @@ def test_require_project_permission_without_project(self, db: Session): """Test that user without project does not have REQUIRE_PROJECT permission""" api_key_response = create_test_api_key(db) - auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) auth_context.project_id = None auth_context.project = None @@ -100,7 +112,9 @@ def test_permission_checker_passes_with_valid_permission(self, db: Session): db.add(user) db.commit() db.refresh(user) - auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) permission_checker = require_permission(Permission.SUPERUSER) permission_checker(auth_context, db) @@ -108,7 +122,9 @@ def test_permission_checker_passes_with_valid_permission(self, db: Session): def test_permission_checker_raises_403_without_permission(self, db: Session): """Test that permission checker raises HTTPException with 403 when user lacks permission""" api_key_response = create_test_api_key(db) - auth_context = get_auth_context(session=db, token=None, api_key=api_key_response.key) + auth_context = get_auth_context( + session=db, token=None, api_key=api_key_response.key + ) permission_checker = require_permission(Permission.SUPERUSER) From fe288b39731eb4992b2b28dfa70a02074e51e964 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Wed, 8 Oct 2025 09:02:51 +0530 Subject: [PATCH 22/37] Fix read_one method docstring and update logging for API key creation --- backend/app/crud/api_key.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index 20121631..a03c0a18 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -24,7 +24,7 @@ def __init__(self, session: Session, project_id: int): def read_one(self, key_id: UUID) -> APIKey | None: """ - Retrieve a single non-deleted API key by its key_prefix. + Retrieve a single non-deleted API key by its id. """ statement = select(APIKey).where( and_( @@ -82,7 +82,7 @@ def create(self, user_id: int, project_id: int) -> Tuple[str, APIKey]: logger.info( f"[APIKeyCrud.create_api_key] API key created successfully | " - f"{{'api_key_id': '{api_key.id}', 'project_id': {self.project_id}, 'user_id': {user_id}}}" + f"{{'api_key_id': '{api_key.id}', 'project_id': {project_id}, 'user_id': {user_id}}}" ) return raw_key, api_key @@ -90,10 +90,9 @@ def create(self, user_id: int, project_id: int) -> Tuple[str, APIKey]: except Exception as e: logger.error( f"[APIKeyCrud.create_api_key] Failed to create API key | " - f"{{'project_id': {self.project_id}, 'user_id': {user_id}, 'error': '{str(e)}'}}", + f"{{'project_id': {project_id}, 'user_id': {user_id}, 'error': '{str(e)}'}}", exc_info=True, ) - self.session.rollback() raise HTTPException( status_code=500, detail=f"Failed to create API key: {str(e)}" ) From 76ac22e031950890e96de350be3981327289e1b3 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Wed, 8 Oct 2025 09:05:35 +0530 Subject: [PATCH 23/37] Update API key deletion to set updated_at timestamp and modify user_id field in AuthContext --- backend/app/crud/api_key.py | 1 + backend/app/models/auth.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index a03c0a18..eb7632e3 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -107,6 +107,7 @@ def delete(self, key_id: UUID) -> None: api_key.is_deleted = True api_key.deleted_at = now() + api_key.updated_at = now() self.session.add(api_key) self.session.commit() self.session.refresh(api_key) diff --git a/backend/app/models/auth.py b/backend/app/models/auth.py index bfad680b..db40fa9d 100644 --- a/backend/app/models/auth.py +++ b/backend/app/models/auth.py @@ -16,7 +16,7 @@ class TokenPayload(SQLModel): class AuthContext(SQLModel): - user_id: int = Field(foreign_key="user.id") + user_id: int project_id: int | None = None organization_id: int | None = None From 786992e2dcd50030cfbb5f8d7d7895176d7a218c Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Wed, 8 Oct 2025 12:09:27 +0530 Subject: [PATCH 24/37] Fix import statement for get_project_by_id in APIKeyCrud --- backend/app/crud/api_key.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index eb7632e3..374b496e 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -6,7 +6,7 @@ from fastapi import HTTPException from app.models import APIKey, User -from app.crud import get_project_by_id +from app.crud.project import get_project_by_id from app.core.util import now from app.core.security import api_key_manager From 543d07ea32b45b6a109ac2dd3f6426b557431df6 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Fri, 10 Oct 2025 18:04:24 +0530 Subject: [PATCH 25/37] pre commit --- backend/app/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 9d14cd92..70fb3949 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -110,4 +110,4 @@ UserUpdateMe, UsersPublic, UpdatePassword, -) \ No newline at end of file +) From 32ffe9fc0b13ffe74fd80a5acf2a375d76c98fea Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Fri, 10 Oct 2025 18:16:10 +0530 Subject: [PATCH 26/37] fix migration --- ...able.py => a06c34a6d730_refactor_api_key_table.py} | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) rename backend/app/alembic/versions/{d209cddac1fa_refactor_api_key_table.py => a06c34a6d730_refactor_api_key_table.py} (93%) diff --git a/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py b/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py similarity index 93% rename from backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py rename to backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py index 20d66172..0154ad7a 100644 --- a/backend/app/alembic/versions/d209cddac1fa_refactor_api_key_table.py +++ b/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py @@ -1,21 +1,20 @@ """Refactor API key table -Revision ID: d209cddac1fa -Revises: c6fb6d0b5897 -Create Date: 2025-10-03 11:35:13.012517 +Revision ID: a06c34a6d730 +Revises: b30727137e65 +Create Date: 2025-10-10 18:14:46.423720 """ from alembic import op import sqlalchemy as sa import sqlmodel.sql.sqltypes from sqlalchemy.orm import Session - from app.alembic.migrate_api_key import migrate_api_keys, verify_migration # revision identifiers, used by Alembic. -revision = "d209cddac1fa" -down_revision = "c6fb6d0b5897" +revision = 'a06c34a6d730' +down_revision = 'b30727137e65' branch_labels = None depends_on = None From 1fa1c3201ff44e89f054be99766437b71518e200 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Fri, 10 Oct 2025 18:17:42 +0530 Subject: [PATCH 27/37] precommit --- .../alembic/versions/a06c34a6d730_refactor_api_key_table.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py b/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py index 0154ad7a..9dc2c13c 100644 --- a/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py +++ b/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py @@ -13,8 +13,8 @@ # revision identifiers, used by Alembic. -revision = 'a06c34a6d730' -down_revision = 'b30727137e65' +revision = "a06c34a6d730" +down_revision = "b30727137e65" branch_labels = None depends_on = None From f51bbcd1e647c65837b1166d94bdb4429f77734c Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Thu, 16 Oct 2025 12:49:49 +0530 Subject: [PATCH 28/37] Refactor API key handling and improve documentation - Updated downgrade function to clarify snapshot restoration. - Changed conditional logic in get_current_user function for clarity. - Enhanced create_api_key_route response with security message about raw API key. - Updated APIKeyManager docstring - removed unnecessary comments --- .../a06c34a6d730_refactor_api_key_table.py | 1 + backend/app/api/deps.py | 2 +- backend/app/api/routes/api_keys.py | 10 +++++++--- backend/app/core/security.py | 14 ++++++++++++-- backend/app/tests/core/test_security.py | 10 ---------- 5 files changed, 21 insertions(+), 16 deletions(-) diff --git a/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py b/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py index 9dc2c13c..79c5893f 100644 --- a/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py +++ b/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py @@ -61,4 +61,5 @@ def upgrade(): def downgrade(): + # instead of downgrade, will take a db snapshot and restore from that if needed pass diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 1fc5d619..c05bd413 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -60,7 +60,7 @@ def get_current_user( return user # Return only User object - if token: + elif token: try: payload = jwt.decode( token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index 7bbbfd10..ca6fbc4c 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -35,8 +35,12 @@ def create_api_key_route( ) api_key = APIKeyCreateResponse(**api_key.model_dump(), key=raw_key) - - return APIResponse.success_response(api_key) + return APIResponse.success_response( + data=api_key, + metadata={ + "message": "The raw API key is returned only once during creation. Store it securely as it cannot be retrieved again." + }, + ) @router.get( @@ -53,7 +57,7 @@ def list_api_keys_route( """ List all API keys for the current project. - Returns masked keys for security - the full key is only shown during creation. + Returns key prefix for security - the full key is only shown during creation. Supports pagination via skip and limit parameters. """ crud = APIKeyCrud(session, current_user.project_id) diff --git a/backend/app/core/security.py b/backend/app/core/security.py index ddd24a16..e12ffdcd 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -190,8 +190,18 @@ def decrypt_credentials(encrypted_credentials: str) -> dict: class APIKeyManager: """ - Handles API key generation and verification using secure hashing. - Supports Backwards compatibility with old key format. + Handles secure API key generation and verification. + + Overview: + - **Old Format (Legacy)**: 43 chars after "ApiKey ", with 12-char prefix and 31-char secret. + - **New Format (Current)**: 65 chars after "ApiKey ", with 22-char prefix and 43-char secret. + - Generates cryptographically secure API keys with fixed lengths, + storing only the hashed secret using bcrypt while keeping the prefix in plaintext for quick lookup. + Raw keys are displayed only once during creation for security. + The system automatically verifies both old and new key formats to ensure backward compatibility. + + Compatibility: + Both old and new formats are supported automatically during verification. """ # Configuration constants diff --git a/backend/app/tests/core/test_security.py b/backend/app/tests/core/test_security.py index 88e66bea..f7974bab 100644 --- a/backend/app/tests/core/test_security.py +++ b/backend/app/tests/core/test_security.py @@ -172,10 +172,8 @@ def test_generate_hash_is_bcrypt(self): def test_extract_key_parts_new_format(self): """Test extracting key parts from new format (65 chars).""" - # Generate a test key with known format raw_key, expected_prefix, _ = APIKeyManager.generate() - # Extract parts result = APIKeyManager._extract_key_parts(raw_key) assert result is not None @@ -206,7 +204,6 @@ def test_extract_key_parts_invalid_prefix(self): def test_extract_key_parts_invalid_length(self): """Test that invalid length returns None.""" - # Too short invalid_key = f"{APIKeyManager.PREFIX_NAME}tooshort" result = APIKeyManager._extract_key_parts(invalid_key) @@ -217,7 +214,6 @@ def test_verify_valid_key(self, db: Session): """Test verifying a valid API key.""" api_key = create_test_api_key(db) - # Verify the key auth_context = APIKeyManager.verify(db, api_key.key) user = db.get(User, api_key.user_id) @@ -238,7 +234,6 @@ def test_verify_invalid_key(self, db: Session): # Generate a key but don't store it raw_key, _, _ = APIKeyManager.generate() - # Verify should return None auth_context = APIKeyManager.verify(db, raw_key) assert auth_context is None @@ -257,16 +252,13 @@ def test_verify_wrong_secret(self, db: Session): def test_verify_deleted_key(self, db: Session): """Test that deleted API keys cannot be verified.""" - # Create test API key api_key_response = create_test_api_key(db) raw_key = api_key_response.key - # Mark the API key as deleted api_key = db.get(APIKey, api_key_response.id) api_key.is_deleted = True db.commit() - # Verify should return None for deleted key auth_context = APIKeyManager.verify(db, raw_key) assert auth_context is None @@ -298,10 +290,8 @@ def test_key_length_constants(self): def test_generate_creates_verifiable_key(self, db: Session): """Integration test: generated key can be verified.""" - # Create test API key api_key_response = create_test_api_key(db) - # Verify the key works auth_context = APIKeyManager.verify(db, api_key_response.key) assert auth_context is not None From 417c15c5c56c25328edda7108b1f5cceebbb837a Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Thu, 16 Oct 2025 13:03:42 +0530 Subject: [PATCH 29/37] Update APIKeyManager docstring and enhance seed data comment for clarity --- backend/app/core/security.py | 2 +- backend/app/seed_data/seed_data.json | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/app/core/security.py b/backend/app/core/security.py index e12ffdcd..1ba1ad3f 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -285,7 +285,7 @@ def verify(cls, session: Session, raw_key: str) -> AuthContext | None: raw_key: The raw API key to verify Returns: - Tuple of (APIKey, User, Organization, Project) if valid, None otherwise + AuthContext if valid, None otherwise """ try: key_parts = cls._extract_key_parts(raw_key) diff --git a/backend/app/seed_data/seed_data.json b/backend/app/seed_data/seed_data.json index 19602dc3..6eb4860a 100644 --- a/backend/app/seed_data/seed_data.json +++ b/backend/app/seed_data/seed_data.json @@ -1,4 +1,5 @@ { + "_comment":"This data will be used in testing also, modifying this will affect tests. Please ensure to update tests/utils/auth.py if you change this data.", "organization": [ { "name": "Project Tech4dev", From 8ce8a1293a2e32553f377f3c76bfd87c4859a934 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Thu, 16 Oct 2025 13:10:32 +0530 Subject: [PATCH 30/37] fix migration head --- ...table.py => e7c68e43ce6f_refactor_api_key_table.py} | 10 +++++----- backend/app/tests/utils/test_data.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) rename backend/app/alembic/versions/{a06c34a6d730_refactor_api_key_table.py => e7c68e43ce6f_refactor_api_key_table.py} (93%) diff --git a/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py b/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py similarity index 93% rename from backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py rename to backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py index 79c5893f..c5037d45 100644 --- a/backend/app/alembic/versions/a06c34a6d730_refactor_api_key_table.py +++ b/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py @@ -1,8 +1,8 @@ """Refactor API key table -Revision ID: a06c34a6d730 -Revises: b30727137e65 -Create Date: 2025-10-10 18:14:46.423720 +Revision ID: e7c68e43ce6f +Revises: 27c271ab6dd0 +Create Date: 2025-10-16 13:06:51.777671 """ from alembic import op @@ -13,8 +13,8 @@ # revision identifiers, used by Alembic. -revision = "a06c34a6d730" -down_revision = "b30727137e65" +revision = 'e7c68e43ce6f' +down_revision = '27c271ab6dd0' branch_labels = None depends_on = None diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 8ff76b0f..c560bbca 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -103,7 +103,7 @@ def create_test_api_key( api_key_crud = APIKeyCrud(session=db, project_id=project_id) raw_key, api_key = api_key_crud.create(user_id=user_id, project_id=project_id) - return APIKeyCreateResponse(key=raw_key, **api_key.dict()) + return APIKeyCreateResponse(key=raw_key, **api_key.model_dump()) def create_test_credential(db: Session) -> tuple[list[Credential], Project]: From 01f705ad6c9a7f4874f112d79c5182873db30b3d Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Thu, 16 Oct 2025 13:16:18 +0530 Subject: [PATCH 31/37] precomit --- .../alembic/versions/e7c68e43ce6f_refactor_api_key_table.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py b/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py index c5037d45..262ad3e6 100644 --- a/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py +++ b/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py @@ -13,8 +13,8 @@ # revision identifiers, used by Alembic. -revision = 'e7c68e43ce6f' -down_revision = '27c271ab6dd0' +revision = "e7c68e43ce6f" +down_revision = "27c271ab6dd0" branch_labels = None depends_on = None From 04126d0a0f173bf2af4d7f2a33593d382e48f577 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Thu, 16 Oct 2025 13:23:35 +0530 Subject: [PATCH 32/37] Refactor API key migration to use context manager for session handling and improve code clarity --- .../versions/e7c68e43ce6f_refactor_api_key_table.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py b/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py index 262ad3e6..2c0778e2 100644 --- a/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py +++ b/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py @@ -36,12 +36,14 @@ def upgrade(): # Step 3: Migrate existing encrypted keys to the new hashed format and generate UUIDs bind = op.get_bind() - session = Session(bind=bind) - migrate_api_keys(session, generate_uuid=True) + with Session(bind=bind) as session: + migrate_api_keys(session, generate_uuid=True) - # Step 4: Verify migration was successful - if not verify_migration(session): - raise Exception("API key migration verification failed. Please check the logs.") + # Step 4: Verify migration was successful + if not verify_migration(session): + raise Exception("API key migration verification failed. Please check the logs.") + + session.flush() # Step 5: Make the columns non-nullable after migration op.alter_column("apikey", "key_prefix", nullable=False) From 890bbd59e0f5793c90c5678ee0ee8002a6bc8e77 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Thu, 16 Oct 2025 13:25:33 +0530 Subject: [PATCH 33/37] precommit --- .../alembic/versions/e7c68e43ce6f_refactor_api_key_table.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py b/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py index 2c0778e2..42d5080c 100644 --- a/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py +++ b/backend/app/alembic/versions/e7c68e43ce6f_refactor_api_key_table.py @@ -41,8 +41,10 @@ def upgrade(): # Step 4: Verify migration was successful if not verify_migration(session): - raise Exception("API key migration verification failed. Please check the logs.") - + raise Exception( + "API key migration verification failed. Please check the logs." + ) + session.flush() # Step 5: Make the columns non-nullable after migration From fb7fe60e0dab0a30257a26b0497c937e7ee29e43 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:37:08 +0530 Subject: [PATCH 34/37] remove id's from authcontext --- backend/app/api/deps.py | 12 ++++++------ backend/app/api/permissions.py | 4 ++-- backend/app/api/routes/api_keys.py | 4 ++-- backend/app/core/security.py | 3 --- backend/app/models/auth.py | 4 ---- backend/app/tests/api/test_permissions.py | 2 -- backend/app/tests/core/test_security.py | 8 ++++---- 7 files changed, 14 insertions(+), 23 deletions(-) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 5ff88880..a9e4e46e 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -49,7 +49,7 @@ def get_current_user( if not api_key_record: raise HTTPException(status_code=401, detail="Invalid API Key") - user = session.get(User, api_key_record.user_id) + user = session.get(User, api_key_record.user.id) if not user: raise HTTPException( status_code=404, detail="User linked to API Key not found" @@ -93,8 +93,8 @@ def get_current_user_org( if api_key: api_key_record = api_key_manager.verify(session, api_key) if api_key_record: - validate_organization(session, api_key_record.organization_id) - organization_id = api_key_record.organization_id + validate_organization(session, api_key_record.organization.id) + organization_id = api_key_record.organization.id return UserOrganization( **current_user.model_dump(), organization_id=organization_id @@ -114,9 +114,9 @@ def get_current_user_org_project( if api_key: api_key_record = api_key_manager.verify(session, api_key) if api_key_record: - validate_organization(session, api_key_record.organization_id) - organization_id = api_key_record.organization_id - project_id = api_key_record.project_id + validate_organization(session, api_key_record.organization.id) + organization_id = api_key_record.organization.id + project_id = api_key_record.project.id else: raise HTTPException(status_code=401, detail="Invalid API Key") diff --git a/backend/app/api/permissions.py b/backend/app/api/permissions.py index b5a99c52..4142b7a4 100644 --- a/backend/app/api/permissions.py +++ b/backend/app/api/permissions.py @@ -35,9 +35,9 @@ def has_permission( case Permission.SUPERUSER: return auth_context.user.is_superuser case Permission.REQUIRE_ORGANIZATION: - return auth_context.organization_id is not None + return auth_context.organization is not None case Permission.REQUIRE_PROJECT: - return auth_context.project_id is not None + return auth_context.project is not None case _: return False diff --git a/backend/app/api/routes/api_keys.py b/backend/app/api/routes/api_keys.py index ca6fbc4c..d1821a35 100644 --- a/backend/app/api/routes/api_keys.py +++ b/backend/app/api/routes/api_keys.py @@ -60,7 +60,7 @@ def list_api_keys_route( Returns key prefix for security - the full key is only shown during creation. Supports pagination via skip and limit parameters. """ - crud = APIKeyCrud(session, current_user.project_id) + crud = APIKeyCrud(session, current_user.project.id) api_keys = crud.read_all(skip=skip, limit=limit) return APIResponse.success_response(api_keys) @@ -79,7 +79,7 @@ def delete_api_key_route( """ Delete an API key by its ID. """ - api_key_crud = APIKeyCrud(session=session, project_id=current_user.project_id) + api_key_crud = APIKeyCrud(session=session, project_id=current_user.project.id) api_key_crud.delete(key_id=key_id) return APIResponse.success_response(Message(message="API Key deleted successfully")) diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 1ba1ad3f..8807d705 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -315,9 +315,6 @@ def verify(cls, session: Session, raw_key: str) -> AuthContext | None: return None api_key_record, user, organization, project = result auth_context = AuthContext( - user_id=user.id, - project_id=project.id, - organization_id=organization.id, user=user, project=project, organization=organization, diff --git a/backend/app/models/auth.py b/backend/app/models/auth.py index db40fa9d..adb93aeb 100644 --- a/backend/app/models/auth.py +++ b/backend/app/models/auth.py @@ -16,10 +16,6 @@ class TokenPayload(SQLModel): class AuthContext(SQLModel): - user_id: int - project_id: int | None = None - organization_id: int | None = None - user: User organization: Organization | None = None project: Project | None = None diff --git a/backend/app/tests/api/test_permissions.py b/backend/app/tests/api/test_permissions.py index 1d427d42..2c9092ac 100644 --- a/backend/app/tests/api/test_permissions.py +++ b/backend/app/tests/api/test_permissions.py @@ -60,7 +60,6 @@ def test_require_organization_permission_without_organization(self, db: Session) session=db, token=None, api_key=api_key_response.key ) - auth_context.organization_id = None auth_context.organization = None result = has_permission(auth_context, Permission.REQUIRE_ORGANIZATION, db) @@ -87,7 +86,6 @@ def test_require_project_permission_without_project(self, db: Session): session=db, token=None, api_key=api_key_response.key ) - auth_context.project_id = None auth_context.project = None result = has_permission(auth_context, Permission.REQUIRE_PROJECT, db) diff --git a/backend/app/tests/core/test_security.py b/backend/app/tests/core/test_security.py index f7974bab..59101375 100644 --- a/backend/app/tests/core/test_security.py +++ b/backend/app/tests/core/test_security.py @@ -222,9 +222,9 @@ def test_verify_valid_key(self, db: Session): assert auth_context is not None assert isinstance(auth_context, AuthContext) - assert auth_context.user_id == api_key.user_id - assert auth_context.organization_id == api_key.organization_id - assert auth_context.project_id == api_key.project_id + assert auth_context.user.id == api_key.user_id + assert auth_context.organization.id == api_key.organization_id + assert auth_context.project.id == api_key.project_id assert auth_context.user == user assert auth_context.organization == organization assert auth_context.project == project @@ -295,4 +295,4 @@ def test_generate_creates_verifiable_key(self, db: Session): auth_context = APIKeyManager.verify(db, api_key_response.key) assert auth_context is not None - assert auth_context.user_id == api_key_response.user_id + assert auth_context.user.id == api_key_response.user_id From 645b5107c28553a8d6ca96bab71a15ce4d1deac8 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:54:22 +0530 Subject: [PATCH 35/37] Remove user_id from AuthContext instantiation in get_auth_context function --- backend/app/api/deps.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index a9e4e46e..b975c42c 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -192,7 +192,6 @@ def get_auth_context( raise HTTPException(status_code=403, detail="Inactive user") auth_context = AuthContext( - user_id=user.id, user=user, ) return auth_context From e91fd5c7052bb3a1090aba1be2dd89b28c1b8a03 Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Fri, 17 Oct 2025 08:26:01 +0530 Subject: [PATCH 36/37] Update user authentication to check for active status before returning user object --- backend/app/api/deps.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index b975c42c..6d765375 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -49,13 +49,12 @@ def get_current_user( if not api_key_record: raise HTTPException(status_code=401, detail="Invalid API Key") - user = session.get(User, api_key_record.user.id) - if not user: + if not api_key_record.user.is_active: raise HTTPException( - status_code=404, detail="User linked to API Key not found" + status_code=403, detail="Inactive user" ) - return user # Return only User object + return api_key_record.user # Return only User object elif token: try: From 34632a109e2893ddfcd90ea2617264b2f2cedefe Mon Sep 17 00:00:00 2001 From: Aviraj <100823015+avirajsingh7@users.noreply.github.com> Date: Fri, 17 Oct 2025 08:26:32 +0530 Subject: [PATCH 37/37] precommit --- backend/app/api/deps.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 6d765375..59678d2f 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -50,9 +50,7 @@ def get_current_user( raise HTTPException(status_code=401, detail="Invalid API Key") if not api_key_record.user.is_active: - raise HTTPException( - status_code=403, detail="Inactive user" - ) + raise HTTPException(status_code=403, detail="Inactive user") return api_key_record.user # Return only User object