diff --git a/backend/app/alembic/versions/99f4fc325617_add_organization_project_setup.py b/backend/app/alembic/versions/99f4fc325617_add_organization_project_setup.py index 9122c0f73..c8a7f61c9 100644 --- a/backend/app/alembic/versions/99f4fc325617_add_organization_project_setup.py +++ b/backend/app/alembic/versions/99f4fc325617_add_organization_project_setup.py @@ -37,8 +37,7 @@ def upgrade(): sa.Column("id", sa.Integer(), nullable=False), sa.Column("organization_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( - ["organization_id"], - ["organization.id"], + ["organization_id"], ["organization.id"], ondelete="CASCADE" ), sa.PrimaryKeyConstraint("id"), ) diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 216f02dbb..08f44960b 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -11,6 +11,7 @@ threads, users, utils, + onboarding, credentials, ) from app.core.config import settings @@ -25,6 +26,7 @@ api_router.include_router(project.router) api_router.include_router(project_user.router) api_router.include_router(api_keys.router) +api_router.include_router(onboarding.router) api_router.include_router(credentials.router) diff --git a/backend/app/api/routes/onboarding.py b/backend/app/api/routes/onboarding.py new file mode 100644 index 000000000..0adcfd73b --- /dev/null +++ b/backend/app/api/routes/onboarding.py @@ -0,0 +1,120 @@ +import uuid + +from fastapi import APIRouter, HTTPException, Depends +from pydantic import BaseModel, EmailStr +from sqlmodel import Session + +from app.crud import ( + create_organization, + get_organization_by_name, + create_project, + create_user, + create_api_key, + get_api_key_by_user_org, +) +from app.models import ( + OrganizationCreate, + ProjectCreate, + UserCreate, + APIKeyPublic, + Organization, + Project, + User, + APIKey, +) +from app.core.security import get_password_hash +from app.api.deps import ( + CurrentUser, + SessionDep, + get_current_active_superuser, +) + +router = APIRouter(tags=["onboarding"]) + + +# Pydantic models for input validation +class OnboardingRequest(BaseModel): + organization_name: str + project_name: str + email: EmailStr + password: str + user_name: str + + +class OnboardingResponse(BaseModel): + organization_id: int + project_id: int + user_id: uuid.UUID + api_key: str + + +@router.post( + "/onboard", + dependencies=[Depends(get_current_active_superuser)], + response_model=OnboardingResponse, +) +def onboard_user(request: OnboardingRequest, session: SessionDep): + """ + Handles quick onboarding of a new user : Accepts Organization name, project name, email, password and user name, then gives back an API key which + will be further used for authentication. + """ + try: + existing_organization = get_organization_by_name( + session=session, name=request.organization_name + ) + if existing_organization: + organization = existing_organization + else: + org_create = OrganizationCreate(name=request.organization_name) + organization = create_organization(session=session, org_create=org_create) + + existing_project = ( + session.query(Project).filter(Project.name == request.project_name).first() + ) + if existing_project: + project = existing_project # Use the existing project + else: + project_create = ProjectCreate( + name=request.project_name, organization_id=organization.id + ) + project = create_project(session=session, project_create=project_create) + + existing_user = session.query(User).filter(User.email == request.email).first() + if existing_user: + user = existing_user + else: + user_create = UserCreate( + name=request.user_name, + email=request.email, + password=request.password, + ) + user = create_user(session=session, user_create=user_create) + + existing_key = get_api_key_by_user_org( + db=session, organization_id=organization.id, user_id=user.id + ) + + if existing_key: + raise HTTPException( + status_code=400, + detail="API key already exists for this user and organization", + ) + + api_key_public = create_api_key( + session=session, organization_id=organization.id, user_id=user.id + ) + + user.is_superuser = False + session.add(user) + session.commit() + + return OnboardingResponse( + organization_id=organization.id, + project_id=project.id, + user_id=user.id, + api_key=api_key_public.key, + ) + + except Exception as e: + session.rollback() + raise HTTPException(status_code=400, detail=str(e)) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 1edef5431..c19c098ac 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -6,3 +6,21 @@ ) from .document import DocumentCrud + +from .organization import ( + create_organization, + get_organization_by_id, + get_organization_by_name, + validate_organization, +) + +from .project import create_project, get_project_by_id, get_projects_by_organization + +from .api_key import ( + create_api_key, + get_api_key, + get_api_key_by_user_org, + get_api_key_by_value, + get_api_keys_by_organization, + delete_api_key, +) diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index e646ded9b..150be01d0 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -5,6 +5,8 @@ if TYPE_CHECKING: from .credentials import Credential + from .project import Project + from .api_key import APIKey # Shared properties for an Organization @@ -29,10 +31,15 @@ class Organization(OrganizationBase, table=True): id: int = Field(default=None, primary_key=True) # Relationship back to Creds - api_keys: list["APIKey"] = Relationship(back_populates="organization") + api_keys: list["APIKey"] = Relationship( + back_populates="organization", sa_relationship_kwargs={"cascade": "all, delete"} + ) creds: list["Credential"] = Relationship( back_populates="organization", sa_relationship_kwargs={"cascade": "all, delete"} ) + project: list["Project"] = Relationship( + back_populates="organization", sa_relationship_kwargs={"cascade": "all, delete"} + ) # Properties to return via API diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 3c4dfd9aa..93ae534de 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -1,3 +1,4 @@ +from typing import Optional from sqlmodel import Field, Relationship, SQLModel @@ -29,6 +30,8 @@ class Project(ProjectBase, table=True): back_populates="project", cascade_delete=True ) + organization: Optional["Organization"] = Relationship(back_populates="project") + # Properties to return via API class ProjectPublic(ProjectBase): diff --git a/backend/app/tests/api/routes/test_onboarding.py b/backend/app/tests/api/routes/test_onboarding.py new file mode 100644 index 000000000..f98b128c8 --- /dev/null +++ b/backend/app/tests/api/routes/test_onboarding.py @@ -0,0 +1,134 @@ +import pytest +from fastapi.testclient import TestClient +from app.main import app # Assuming your FastAPI app is in app/main.py +from app.models import Organization, Project, User, APIKey +from app.crud import create_organization, create_project, create_user, create_api_key +from app.api.deps import SessionDep +from sqlalchemy import create_engine +from sqlmodel import Session, SQLModel +from app.core.config import settings +from app.tests.utils.utils import random_email, random_lower_string +from app.core.security import decrypt_api_key + +client = TestClient(app) + + +def test_onboard_user(client, db: Session, superuser_token_headers: dict[str, str]): + data = { + "organization_name": "TestOrg", + "project_name": "TestProject", + "email": random_email(), + "password": "testpassword123", + "user_name": "Test User", + } + + response = client.post( + f"{settings.API_V1_STR}/onboard", json=data, headers=superuser_token_headers + ) + + assert response.status_code == 200 + + response_data = response.json() + assert "organization_id" in response_data + assert "project_id" in response_data + assert "user_id" in response_data + assert "api_key" in response_data + + organization = ( + db.query(Organization) + .filter(Organization.name == data["organization_name"]) + .first() + ) + project = db.query(Project).filter(Project.name == data["project_name"]).first() + user = db.query(User).filter(User.email == data["email"]).first() + api_key = db.query(APIKey).filter(APIKey.user_id == user.id).first() + + assert organization is not None + assert project is not None + assert user is not None + assert api_key is not None + + plain_token = response_data["api_key"] + encrypted_stored = api_key.key + + assert decrypt_api_key(encrypted_stored) == plain_token # main check + assert encrypted_stored != plain_token + + assert user.is_superuser is False + + +def test_create_user_existing_email( + client, db: Session, superuser_token_headers: dict[str, str] +): + data = { + "organization_name": "TestOrg", + "project_name": "TestProject", + "email": random_email(), + "password": "testpassword123", + "user_name": "Test User", + } + + client.post( + f"{settings.API_V1_STR}/onboard", json=data, headers=superuser_token_headers + ) + + response = client.post( + f"{settings.API_V1_STR}/onboard", json=data, headers=superuser_token_headers + ) + + assert response.status_code == 400 + assert ( + response.json()["detail"] + == "400: API key already exists for this user and organization" + ) + + +def test_is_superuser_flag( + client, db: Session, superuser_token_headers: dict[str, str] +): + data = { + "organization_name": "TestOrg", + "project_name": "TestProject", + "email": random_email(), + "password": "testpassword123", + "user_name": "Test User", + } + + response = client.post( + f"{settings.API_V1_STR}/onboard", json=data, headers=superuser_token_headers + ) + + assert response.status_code == 200 + + response_data = response.json() + user = db.query(User).filter(User.id == response_data["user_id"]).first() + assert user is not None + assert user.is_superuser is False + + +def test_organization_and_project_creation( + client, db: Session, superuser_token_headers: dict[str, str] +): + data = { + "organization_name": "NewOrg", + "project_name": "NewProject", + "email": random_email(), + "password": "newpassword123", + "user_name": "New User", + } + + response = client.post( + f"{settings.API_V1_STR}/onboard", json=data, headers=superuser_token_headers + ) + + assert response.status_code == 200 + + organization = ( + db.query(Organization) + .filter(Organization.name == data["organization_name"]) + .first() + ) + project = db.query(Project).filter(Project.name == data["project_name"]).first() + + assert organization is not None + assert project is not None