In [None]:
#| default_exp utils_oauth

# OAuth Authentication

> Multi-user tenant authentication with Google OAuth, CSRF protection, and automatic tenant provisioning.

## Architecture

**Authentication Flow:**
1. User clicks "Login with Google" → generates CSRF state token
2. Google authenticates user → redirects back with auth code
3. Verify CSRF state matches → exchange code for user info
4. Create/get GlobalUser in host DB → check for existing Membership
5. If no membership → auto-provision new tenant with user as owner
6. Store session data (user_id, tenant_id, role) → redirect to dashboard

**Tenant Model:**
- Many users → One tenant (many-to-one relationship)
- Each user belongs to exactly ONE tenant
- Each tenant can have MANY users
- New users auto-create their own tenant (future: payment screen)

**Security:**
- CSRF protection via the OAuth `state` parameter prevents login CSRF (an attacker injecting their own authorization code into a victim's session)
- Membership validation ensures users only access their tenant
- Audit logging tracks all authentication events

**Token Expiry:**
- Google OAuth access tokens expire after 1 hour
- Current: users must re-login after expiry
- Future: implement refresh token flow for seamless re-authentication

In [None]:
#| export

from fastsql import *
from fastcore.utils import *
from fh_saas.db_host import (
    timestamp, gen_id, 
    GlobalUser, TenantCatalog, Membership, HostAuditLog,
    HostDatabase
)
from fh_saas.db_tenant import (
    get_or_create_tenant_db, 
    init_tenant_core_schema, 
    TenantUser
)
from fh_saas.utils_log import DatabaseHandler, TenantContext
from fasthtml.oauth import GoogleAppClient, redir_url
from starlette.responses import RedirectResponse
import os
import uuid
import json
from dotenv import load_dotenv

load_dotenv()

## OAuth Client Setup

In [None]:
#| export

def get_google_oauth_client():
    """
    Build the Google OAuth client from credentials held in the environment.

    Returns:
        GoogleAppClient: Client constructed with the configured ID and secret

    Raises:
        ValueError: If GOOGLE_CLIENT_ID or GOOGLE_CLIENT_SECRET is unset or empty

    Environment Variables Required:
        GOOGLE_CLIENT_ID: OAuth client ID from Google Cloud Console
        GOOGLE_CLIENT_SECRET: OAuth client secret

    Note:
        The redirect URI must be registered in Google Cloud Console:
        - Development: http://localhost:8000/auth/callback
        - Production: https://yourdomain.com/auth/callback
    """
    credentials = {
        'client_id': os.getenv('GOOGLE_CLIENT_ID'),
        'client_secret': os.getenv('GOOGLE_CLIENT_SECRET'),
    }

    # An unset variable (None) and an empty string are both treated as missing.
    if not all(credentials.values()):
        raise ValueError(
            "Missing OAuth credentials. Set GOOGLE_CLIENT_ID and "
            "GOOGLE_CLIENT_SECRET in .env file"
        )

    return GoogleAppClient(**credentials)

## CSRF Protection

### Why CSRF Protection?

**Attack Scenario Without CSRF:**
1. Attacker initiates OAuth flow, gets auth code
2. Attacker sends victim link: `yourapp.com/auth/callback?code=ATTACKER_CODE`
3. Victim clicks link → your app creates session for attacker's account
4. Victim enters sensitive data thinking they're in their own account
5. Attacker sees all victim's data in attacker's account

**Solution: State Token**
- Generate random UUID when user clicks login
- Store in session, pass to Google via state parameter
- Google includes state in callback URL
- Verify state matches before creating session
- If mismatch → reject (prevents attacker code injection)

In [None]:
#| export

def generate_oauth_state():
    """
    Generate a cryptographically secure random state token for CSRF protection.

    The token is drawn from the `secrets` module, which is intended for
    security-sensitive values, rather than from `uuid.uuid4()`, whose purpose
    is uniqueness and which carries fixed version/variant bits.

    Returns:
        str: Random hex string (32 characters, 128 bits of randomness)

    Usage:
        state = generate_oauth_state()
        session['oauth_state'] = state
        login_link = client.login_link(state=state)
    """
    # Imported locally so the block does not depend on a new file-level import.
    import secrets

    # 16 random bytes -> 32 hex characters, the same length callers got before.
    return secrets.token_hex(16)


def verify_oauth_state(session: dict, callback_state: str):
    """
    Verify OAuth callback state matches stored session state (CSRF protection).

    Args:
        session: Session dict expected to contain 'oauth_state'
        callback_state: State parameter from the OAuth callback URL

    Raises:
        ValueError: If state is missing from the session, or if the callback
            state is absent, not a string, or does not match the stored state

    Security:
        Always call this BEFORE exchanging the auth code for tokens.
        Prevents CSRF attacks where an attacker injects their auth code.
        The comparison is constant-time so the check does not leak how many
        leading characters of a guessed state were correct.

    Note:
        On success the stored state is removed from the session (one-time use).
        On failure the session is left untouched.
    """
    # Imported locally so the block does not depend on a new file-level import.
    import hmac

    stored_state = session.get('oauth_state')

    if not stored_state:
        raise ValueError(
            "CSRF validation failed: No state found in session. "
            "This may indicate a direct callback access attempt."
        )

    # hmac.compare_digest rejects non-ASCII str with TypeError, so compare the
    # UTF-8 bytes instead; anything that is not a str (e.g. a missing query
    # parameter arriving as None) can never match.
    states_match = (
        isinstance(stored_state, str)
        and isinstance(callback_state, str)
        and hmac.compare_digest(
            stored_state.encode('utf-8'),
            callback_state.encode('utf-8'),
        )
    )

    if not states_match:
        raise ValueError(
            "CSRF validation failed: State mismatch. "
            "Possible session hijacking attempt detected."
        )

    # Clear state after successful verification (one-time use)
    session.pop('oauth_state', None)

## User Management

In [None]:
#| export

def create_or_get_global_user(
    host_db: HostDatabase,
    oauth_id: str,
    email: str,
    oauth_info: dict = None,
    log_handler: DatabaseHandler = None
):
    """
    Look up the GlobalUser for an OAuth identity, creating it when absent.

    Args:
        host_db: HostDatabase instance
        oauth_id: Unique identifier from the OAuth provider
        email: User's email address
        oauth_info: Full OAuth user info dict (currently unused here)
        log_handler: Optional DatabaseHandler; when given, the outcome is logged

    Returns:
        GlobalUser: The existing record (with last_login refreshed) or the
        newly inserted record

    Raises:
        Exception: Any error from the lookup/insert/update is re-raised after
            the host database has been rolled back and the failure logged

    Note:
        - Rolls back the host database before reading
        - This function does not commit; the caller must commit afterwards
    """
    def _log(**fields):
        # Logging is optional: do nothing when no handler was supplied.
        if log_handler:
            log_handler.write_log(**fields)

    try:
        host_db.rollback()

        # First record whose oauth_id matches, or None for a first-time login.
        user = next(
            (u for u in host_db.global_users() if u.oauth_id == oauth_id),
            None,
        )

        if user is None:
            user = GlobalUser(
                id=gen_id(),
                email=email,
                oauth_id=oauth_id,
                created_at=timestamp(),
                last_login=timestamp()
            )
            host_db.global_users.insert(user)
            message = f'New user created: {email}'
            operation = 'user_created'
        else:
            # Returning user: only the last-login timestamp changes.
            user.last_login = timestamp()
            host_db.global_users.update(user)
            message = f'User login: {email}'
            operation = 'user_login'

        _log(
            level='INFO',
            message=message,
            operation=operation,
            status='success',
            user_id=user.id,
            email=email
        )
        return user

    except Exception as e:
        host_db.rollback()
        _log(
            level='ERROR',
            message=f'Failed to create/get user {email}: {str(e)}',
            operation='user_auth',
            status='error',
            error=str(e)
        )
        raise


def get_user_membership(
    host_db: HostDatabase,
    user_id: str,
    log_handler: DatabaseHandler = None
):
    """
    Return the first active membership belonging to a user.

    Args:
        host_db: HostDatabase instance
        user_id: GlobalUser.id to look up
        log_handler: Optional DatabaseHandler (accepted but not used here)

    Returns:
        Membership | None: First membership whose user_id matches and whose
        is_active flag is truthy, or None when there is none

    Note:
        Rolls back the host database before reading to clear stale
        transaction state; nothing is written.
    """
    host_db.rollback()  # Clear any stale transaction state

    for candidate in host_db.memberships():
        if candidate.user_id == user_id and candidate.is_active:
            return candidate

    return None


def verify_membership(
    host_db: HostDatabase,
    user_id: str,
    tenant_id: str,
    log_handler: DatabaseHandler = None
) -> bool:
    """
    Check whether a user holds an active membership in a given tenant.

    Args:
        host_db: HostDatabase instance
        user_id: GlobalUser.id to check
        tenant_id: Tenant identifier to check against
        log_handler: Optional DatabaseHandler (accepted but not used here)

    Returns:
        bool: True if at least one membership matches both ids and is active,
        False otherwise

    Security:
        Intended to be called before granting access to a tenant database so
        that users cannot reach another tenant's data.

    Note:
        Rolls back the host database before reading to clear stale
        transaction state; nothing is written.
    """
    host_db.rollback()  # Clear any stale transaction state

    return any(
        m.user_id == user_id and m.tenant_id == tenant_id and m.is_active
        for m in host_db.memberships()
    )

## Membership & Tenant Access

## Auto-Provisioning (New User Onboarding)

When a new user logs in for the first time:
1. Create physical tenant database (PostgreSQL or SQLite) and register it in the host catalog
2. Initialize core tenant schema (users, permissions, settings)
3. Create TenantUser profile in tenant database
4. Create membership linking user to tenant
5. Record the provisioning event in the host audit log

**Future Enhancement:** Insert payment screen before step 1

In [None]:
#| export

def provision_new_user(
    host_db: HostDatabase,
    global_user: GlobalUser,
    log_handler: DatabaseHandler = None
) -> str:
    """
    Provision a fresh tenant for a first-time user and make them its owner.

    Args:
        host_db: HostDatabase instance
        global_user: GlobalUser record the tenant is created for
        log_handler: Optional DatabaseHandler; when given, start, success and
            failure are logged

    Returns:
        str: ID generated for the new tenant

    Raises:
        Exception: Wraps any failure (chained to the original error) after the
            host database has been rolled back

    Steps performed, in order:
        1. get_or_create_tenant_db(tenant_id, tenant_name)
        2. init_tenant_core_schema on the tenant database
        3. Insert a TenantUser (local_role='admin') sharing the GlobalUser id
        4. Insert a Membership (role='owner') into the host database
        5. Insert a HostAuditLog entry ('tenant_provisioned')
        6. Commit the host database

    Note:
        Only the host database is rolled back on failure; this function does
        not undo work already done by get_or_create_tenant_db or on the
        tenant database.
    """
    def _log(level, message, operation, status, **extra):
        # Logging is optional: do nothing when no handler was supplied.
        if log_handler:
            log_handler.write_log(
                level=level,
                message=message,
                operation=operation,
                status=status,
                **extra
            )

    tenant_id = gen_id()
    # Workspace name is derived from the part of the email before the first '@'.
    username, _, _ = global_user.email.partition('@')
    tenant_name = f"{username}'s Workspace"

    try:
        _log(
            'INFO',
            f'Starting tenant provisioning for {global_user.email}',
            'tenant_provision_start',
            'info',
            tenant_id=tenant_id,
            user_email=global_user.email,
            user_id=global_user.id
        )

        # 1. Create physical tenant database and register in catalog
        tenant_db = get_or_create_tenant_db(tenant_id, tenant_name)

        # 2. Initialize core tenant schema
        core_tables = init_tenant_core_schema(tenant_db)

        # 3. Tenant-side profile; its id MUST match GlobalUser.id
        core_tables['tenant_users'].insert(
            TenantUser(
                id=global_user.id,
                display_name=username,
                local_role='admin',  # First user is admin
                created_at=timestamp()
            )
        )

        # 4. Host-side membership; profile_id links to TenantUser.id
        host_db.memberships.insert(
            Membership(
                id=gen_id(),
                user_id=global_user.id,
                tenant_id=tenant_id,
                profile_id=global_user.id,
                role='owner',  # First user owns the tenant
                created_at=timestamp()
            )
        )

        # 5. Audit trail entry for the provisioning event
        host_db.audit_logs.insert(
            HostAuditLog(
                id=gen_id(),
                actor_user_id=global_user.id,
                event_type='tenant_provisioned',
                target_id=tenant_id,
                details=json.dumps({
                    'tenant_name': tenant_name,
                    'plan_tier': 'free',
                    'user_email': global_user.email
                }),
                created_at=timestamp()
            )
        )

        # Commit all host database changes
        host_db.commit()

        _log(
            'INFO',
            f'Tenant provisioned successfully: {tenant_name}',
            'tenant_provision_complete',
            'success',
            tenant_id=tenant_id,
            tenant_name=tenant_name,
            user_email=global_user.email,
            user_id=global_user.id
        )

        return tenant_id

    except Exception as e:
        # Rollback all host database changes
        host_db.rollback()

        _log(
            'ERROR',
            f'Tenant provisioning failed for {global_user.email}: {str(e)}',
            'tenant_provision_failed',
            'error',
            tenant_id=tenant_id,
            user_email=global_user.email,
            error=str(e),
            user_id=global_user.id
        )

        raise Exception(f"Failed to provision tenant for {global_user.email}: {str(e)}") from e

## Session Management

In [None]:
#| export

def create_user_session(session: dict, global_user: GlobalUser, membership: Membership):
    """
    Populate the session with the identity of a freshly authenticated user.

    Args:
        session: Session dict to populate (mutated in place)
        global_user: Authenticated user from the host database
        membership: The user's tenant membership record

    Session Keys Set:
        user_id: GlobalUser.id
        email: GlobalUser.email
        tenant_id: Membership.tenant_id
        tenant_role: Membership.role
        is_sys_admin: GlobalUser.is_sys_admin
        login_at: Value of timestamp() at session creation
    """
    # (session key, source record, attribute) — written in this order.
    copied_fields = (
        ('user_id', global_user, 'id'),
        ('email', global_user, 'email'),
        ('tenant_id', membership, 'tenant_id'),
        ('tenant_role', membership, 'role'),
        ('is_sys_admin', global_user, 'is_sys_admin'),
    )
    for session_key, record, attribute in copied_fields:
        session[session_key] = getattr(record, attribute)

    session['login_at'] = timestamp()


def get_current_user(session: dict) -> dict | None:
    """
    Extract current user info from session.
    
    Args:
        session: Starlette session dict
    
    Returns:
        dict | None: User info dict or None if not authenticated
    
    Usage:
        user = get_current_user(session)
        if not user:
            return RedirectResponse('/login', status_code=303)
    """
    if 'user_id' not in session:
        return None
    
    return {
        'user_id': session.get('user_id'),
        'email': session.get('email'),
        'tenant_id': session.get('tenant_id'),
        'tenant_role': session.get('tenant_role'),
        'is_sys_admin': session.get('is_sys_admin', False)
    }


def clear_session(session: dict):
    """
    Log the user out by emptying the session.

    Args:
        session: Session dict to empty (mutated in place)

    Note:
        Every key is removed, including any that are unrelated to
        authentication, so the next request is treated as unauthenticated.
    """
    session.clear()

## Route Helpers & Authorization

In [None]:
#| export

def route_user_after_login(global_user: GlobalUser, membership: Membership = None) -> str:
    """
    Choose the post-login redirect path for a user.

    Args:
        global_user: Authenticated user from the host database
        membership: The user's tenant membership, if any

    Returns:
        str: '/admin/dashboard' for system admins, '/dashboard' for users
        with a membership

    Raises:
        ValueError: If the user is not a system admin and has no membership

    Note:
        The system-admin check takes precedence: an admin is sent to the
        admin dashboard even when a membership is supplied.
    """
    if global_user.is_sys_admin:
        return '/admin/dashboard'

    # A regular user without a membership should have been provisioned
    # earlier, so reaching this point signals a provisioning failure.
    if not membership:
        raise ValueError(
            f"User {global_user.email} has no membership. "
            "This indicates a provisioning failure."
        )

    return '/dashboard'


def require_tenant_access(session: dict):
    """
    Gate access to a tenant database on an authenticated, active membership.

    Args:
        session: Session dict

    Returns:
        The result of get_or_create_tenant_db for the session's tenant

    Raises:
        ValueError: If the session holds no authenticated user
        PermissionError: If verify_membership finds no active membership for
            the session's user and tenant

    Security:
        CRITICAL: Call this at the start of every tenant route handler so the
        membership is verified before any tenant data is touched.

    Usage:
        @app.get('/tenant/transactions')
        def get_transactions(session):
            tenant_db = require_tenant_access(session)
            # Now safe to query tenant database
            return tenant_db.t.transactions()
    """
    current = get_current_user(session)
    if not current:
        raise ValueError("Authentication required. Please login.")

    user_id, tenant_id = current['user_id'], current['tenant_id']

    # Membership is re-checked against the host database on every call rather
    # than trusting the tenant_id stored in the session (CRITICAL SECURITY CHECK).
    has_access = verify_membership(HostDatabase.from_env(), user_id, tenant_id)
    if not has_access:
        raise PermissionError(
            f"Access denied. User {user_id} does not have active "
            f"membership for tenant {tenant_id}."
        )

    return get_or_create_tenant_db(tenant_id)

## OAuth Route Handlers (Pseudo-code)

These functions show how to integrate with FastHTML routes.
Actual route implementation will be in main app file.

In [None]:
#| export

def handle_login_request(request, session):
    """
    Handle /login route - build the OAuth login link with CSRF protection.

    Usage in FastHTML:
        @app.get('/login')
        def login(request, session):
            return handle_login_request(request, session)

    Args:
        request: Request object, used to derive the callback URL
        session: Session dict; receives the generated 'oauth_state'

    Returns:
        The login link produced by the OAuth client for the
        '/auth/callback' redirect URI and the generated state
    """
    # Store the CSRF token first so the callback can verify it later.
    csrf_token = generate_oauth_state()
    session['oauth_state'] = csrf_token

    oauth_client = get_google_oauth_client()
    callback_url = redir_url(request, '/auth/callback')

    return oauth_client.login_link(redirect_uri=callback_url, state=csrf_token)


def handle_oauth_callback(code: str, state: str, request, session):
    """
    Handle /auth/callback route - complete OAuth flow with CSRF validation.

    Usage in FastHTML:
        @app.get('/auth/callback')
        def auth_callback(code: str, state: str, request, session):
            return handle_oauth_callback(code, state, request, session)

    Args:
        code: Authorization code from the OAuth provider
        state: CSRF state token from the callback (must match the session)
        request: Request object, used to derive the callback URL
        session: Session dict

    Returns:
        RedirectResponse: 303 redirect to the path chosen by
        route_user_after_login

    Raises:
        ValueError: If CSRF validation fails, or if a non-admin user still has
            no membership after provisioning
        Exception: Propagated from user creation or tenant provisioning

    Flow:
        1. Verify CSRF state
        2. Exchange code for user info
        3. Create/get GlobalUser and commit it
        4. Check membership or provision new tenant
        5. Resolve the redirect target (fails before any session write)
        6. Create session
        7. Redirect
    """
    # Step 1: CSRF validation (CRITICAL - must be first)
    verify_oauth_state(session, state)

    # Step 2: Exchange authorization code for user info
    client = get_google_oauth_client()
    redirect_uri = redir_url(request, '/auth/callback')
    user_info = client.retr_info(code, redirect_uri)

    # Step 3: Create or get GlobalUser.
    # The user/membership helpers all take the host database as their first
    # argument; obtain it the same way require_tenant_access does.
    host_db = HostDatabase.from_env()
    oauth_id = user_info[client.id_key]
    email = user_info.get('email', '')
    global_user = create_or_get_global_user(host_db, oauth_id, email, user_info)

    # create_or_get_global_user leaves the commit to its caller, and
    # get_user_membership starts with a rollback, so commit now or the
    # insert / last_login update would be discarded.
    host_db.commit()

    # Step 4: Check for existing membership; auto-provision a tenant for a
    # first-time regular user (system admins have no tenant).
    membership = get_user_membership(host_db, global_user.id)
    if not membership and not global_user.is_sys_admin:
        provision_new_user(host_db, global_user)
        membership = get_user_membership(host_db, global_user.id)

    # Step 5: Resolve the destination before touching the session. This raises
    # for a non-admin without membership, so such a user can never fall into
    # the admin branch below and be marked as a system admin.
    redirect_url = route_user_after_login(global_user, membership)

    # Step 6: Create session
    if membership:
        create_user_session(session, global_user, membership)
    else:
        # Only a system admin without a tenant reaches here - minimal session
        session['user_id'] = global_user.id
        session['email'] = global_user.email
        session['is_sys_admin'] = True
        session['login_at'] = timestamp()

    # Step 7: Route to appropriate dashboard
    return RedirectResponse(redirect_url, status_code=303)


def handle_logout(session):
    """
    Handle /logout route - empty the session and send the user to login.

    Usage in FastHTML:
        @app.get('/logout')
        def logout(session):
            return handle_logout(session)

    Args:
        session: Session dict to clear

    Returns:
        RedirectResponse: 303 redirect to '/login'
    """
    clear_session(session)
    login_redirect = RedirectResponse('/login', status_code=303)
    return login_redirect