In [None]:
# Test imports
from fh_saas.utils_oauth import (
    generate_oauth_state, verify_oauth_state,
    create_or_get_global_user, get_user_membership, verify_membership,
    provision_new_user, create_user_session, get_current_user, clear_session,
    route_user_after_login, require_tenant_access
)
from fh_saas.db_host import HostDatabase
from fh_saas.db_tenant import get_or_create_tenant_db
from sqlalchemy import text
from dotenv import load_dotenv
load_dotenv()


In [None]:
#| hide


print("🧪 Running OAuth Tests...\n")

# Initialize HostDatabase singleton; every test cell below shares this handle.
host_db = HostDatabase.from_env()

# ==========================================
# CLEANUP: Make tests idempotent
# ==========================================

# Every account created by these tests uses an @test_oauth.com address.
OAUTH_TEST_EMAIL_PATTERN = '%@test_oauth.com'


def cleanup_oauth_test_data():
    """Remove data left by earlier runs so the OAuth tests can be re-run idempotently.

    Tenant ids are chosen by ``provision_new_user``, so they are not guaranteed to
    carry a ``test_oauth`` prefix. The test tenants are therefore located through
    the test users: users are matched by email, each user's tenant is looked up
    with ``get_user_membership``, and that tenant's audit logs, memberships and
    catalog row are deleted before the users themselves (children before parents).
    The original prefix-based deletes are kept as a fallback.

    Raises:
        Exception: whatever the database raised; the transaction is rolled back
            first so the shared host connection is left usable.
    """
    conn = host_db.db.conn
    host_db.rollback()  # discard any half-finished transaction from a failed cell
    try:
        test_user_ids = [
            row[0]
            for row in conn.execute(
                text("SELECT id FROM core_users WHERE email LIKE :pattern"),
                {"pattern": OAUTH_TEST_EMAIL_PATTERN},
            )
        ]

        test_tenant_ids = set()
        for user_id in test_user_ids:
            membership = get_user_membership(host_db, user_id)
            if membership is not None:
                test_tenant_ids.add(membership.tenant_id)

        # Delete per tenant in reverse dependency order, using bound parameters.
        # target_id is the audit-log field TEST 2 uses to find provisioning entries.
        for tid in test_tenant_ids:
            params = {"tid": tid}
            conn.execute(text("DELETE FROM sys_audit_logs WHERE target_id = :tid"), params)
            conn.execute(text("DELETE FROM core_memberships WHERE tenant_id = :tid"), params)
            conn.execute(text("DELETE FROM core_tenants WHERE id = :tid"), params)

        # Fallback: the original prefix-based cleanup
        conn.execute(text(
            "DELETE FROM sys_audit_logs WHERE event_type='tenant_provisioned' "
            "AND details LIKE '%test_oauth%'"
        ))
        conn.execute(text(
            "DELETE FROM core_memberships WHERE tenant_id LIKE 'test_oauth%'"
        ))
        conn.execute(text(
            "DELETE FROM core_tenants WHERE id LIKE 'test_oauth%'"
        ))
        conn.execute(
            text("DELETE FROM core_users WHERE email LIKE :pattern"),
            {"pattern": OAUTH_TEST_EMAIL_PATTERN},
        )
        host_db.commit()
    except Exception:
        host_db.rollback()
        raise


cleanup_oauth_test_data()
print("🧹 Cleaned up previous test data\n")

In [None]:
#| hide

# ==========================================
# TEST 1: CSRF State Generation & Validation
# ==========================================

print("1️⃣ Testing CSRF State Protection...")

# Every call must yield a fresh, fixed-length token
state1 = generate_oauth_state()
state2 = generate_oauth_state()
assert len(state1) == 32, "State should be 32-char UUID hex"
assert state1 != state2, "Each state should be unique"
print("   ✅ State generation works")

# A matching state must verify successfully
mock_session = {'oauth_state': state1}
try:
    verify_oauth_state(mock_session, state1)
except ValueError as e:
    # Chain the cause so the failure shows why a valid state was rejected
    raise AssertionError("Valid state should not raise error") from e
print("   ✅ Valid state verification works")

# The state is single-use: it must be removed from the session after verification
assert 'oauth_state' not in mock_session, "State should be cleared after use"
print("   ✅ State cleared after verification")

# A mismatched state must be rejected as a CSRF failure
mock_session = {'oauth_state': state1}
try:
    verify_oauth_state(mock_session, state2)
except ValueError as e:
    assert "CSRF validation failed" in str(e), f"Unexpected error message: {e}"
    print("   ✅ State mismatch detected")
else:
    raise AssertionError("State mismatch should raise error")

# A session without any stored state must also be rejected
mock_session = {}
try:
    verify_oauth_state(mock_session, state1)
except ValueError as e:
    assert "No state found" in str(e), f"Unexpected error message: {e}"
    print("   ✅ Missing state detected")
else:
    raise AssertionError("Missing state should raise error")

In [None]:
#| hide

# ==========================================
# TEST 2: New User Auto-Provisioning
# ==========================================

print("\n2️⃣ Testing New User Auto-Provisioning...")

# Simulate the OAuth callback for a user that has never logged in before
new_user = create_or_get_global_user(
    host_db=host_db,
    oauth_id='google_new_user_123',
    email='newuser@test_oauth.com'
)
host_db.commit()  # the caller owns this transaction
assert new_user.email == 'newuser@test_oauth.com'
assert new_user.oauth_id == 'google_new_user_123'
print("   ✅ GlobalUser created")

# A brand-new user is not attached to any tenant yet
membership = get_user_membership(host_db, new_user.id)
assert membership is None, "New user should have no membership"
print("   ✅ No membership found (expected)")

# Auto-provision a tenant (commits internally)
tenant_id = provision_new_user(host_db, new_user)
assert tenant_id is not None
print(f"   ✅ Tenant provisioned: {tenant_id}")

# Roll back first so the read below starts a fresh transaction and sees the
# rows committed by provision_new_user
host_db.rollback()
membership = get_user_membership(host_db, new_user.id)
assert membership is not None, "Membership should exist after provisioning"
assert membership.tenant_id == tenant_id
assert membership.role == 'owner', "First user should be owner"
print("   ✅ Membership created with 'owner' role")

# The tenant must be registered exactly once in the host catalog
matching_tenants = [t for t in host_db.tenant_catalogs() if t.id == tenant_id]
assert len(matching_tenants) == 1, "Tenant should be registered"
assert "newuser's Workspace" in matching_tenants[0].name
print("   ✅ Tenant registered in catalog")

# The user must also exist as a TenantUser inside the tenant's own database.
# The workspace name is rebuilt as "<email local part>'s Workspace".
from fh_saas.db_tenant import init_tenant_core_schema
username = new_user.email.split('@')[0]
tenant_name = f"{username}'s Workspace"
tenant_db = get_or_create_tenant_db(tenant_id, tenant_name)
core_tables = init_tenant_core_schema(tenant_db)
tenant_db.conn.rollback()  # NOTE(review): presumably clears stale transaction state before reading
tenant_users = core_tables['tenant_users']()
matching_tenant_users = [u for u in tenant_users if u.id == new_user.id]
assert len(matching_tenant_users) == 1, "TenantUser should exist"
assert matching_tenant_users[0].local_role == 'admin'
print("   ✅ TenantUser created in tenant database")

# Provisioning must leave exactly one audit entry targeting the new tenant
provisioning_logs = [entry for entry in host_db.audit_logs() if entry.target_id == tenant_id]
assert len(provisioning_logs) == 1, "Provisioning should be logged"
assert provisioning_logs[0].event_type == 'tenant_provisioned'
print("   ✅ Audit log created")

In [None]:
#| hide

# ==========================================
# TEST 3: Returning User Login
# ==========================================

print("\n3️⃣ Testing Returning User Login...")

# Simulate the OAuth callback again with the same identity as TEST 2
returning_user = create_or_get_global_user(
    host_db=host_db,
    oauth_id='google_new_user_123',  # Same oauth_id as TEST 2
    email='newuser@test_oauth.com'
)
host_db.commit()
assert returning_user.id == new_user.id, "Should return same user"
assert returning_user.oauth_id == 'google_new_user_123'
print("   ✅ Existing user retrieved")

# The membership provisioned in TEST 2 must still be present
membership = get_user_membership(host_db, returning_user.id)
assert membership is not None
assert membership.tenant_id == tenant_id
print("   ✅ Membership found")

# Build the login session from user + membership
mock_session = {}
create_user_session(mock_session, returning_user, membership)
assert mock_session['user_id'] == returning_user.id
assert mock_session['tenant_id'] == tenant_id
assert mock_session['tenant_role'] == 'owner'
print("   ✅ Session created")

# The session alone must be enough to identify the logged-in user
current_user = get_current_user(mock_session)
assert current_user is not None
assert current_user['email'] == 'newuser@test_oauth.com'
print("   ✅ Current user retrieved from session")

In [None]:
#| hide

# ==========================================
# TEST 4: Cross-Tenant Access Prevention (CRITICAL)
# ==========================================

print("\n4️⃣ Testing Cross-Tenant Access Prevention (CRITICAL SECURITY)...")

# Create a second user with their own tenant
user2 = create_or_get_global_user(
    host_db=host_db,
    oauth_id='google_user2_456',
    email='user2@test_oauth.com'
)
host_db.commit()
tenant2_id = provision_new_user(host_db, user2)  # commits internally
assert tenant2_id is not None
assert tenant2_id != tenant_id, "Each user must get a separate tenant"

# Fresh read (same pattern as TEST 2), then actually check the new membership
host_db.rollback()
membership2 = get_user_membership(host_db, user2.id)
assert membership2 is not None, "User2 should have a membership after provisioning"
assert membership2.tenant_id == tenant2_id
assert membership2.role == 'owner', "First user should be owner"
print(f"   ✅ Second user and tenant created: {tenant2_id}")

# Positive controls: without these, a verify_membership that always returned
# False would pass every "blocked" assertion below.
assert verify_membership(host_db, new_user.id, tenant_id), "User1 should have access to tenant1"
assert verify_membership(host_db, user2.id, tenant2_id), "User2 should have access to tenant2"
print("   ✅ Each user is a verified member of their own tenant")

# User1 must not be a member of tenant2
can_access = verify_membership(host_db, new_user.id, tenant2_id)
assert not can_access, "User1 should NOT have access to tenant2"
print("   ✅ User1 blocked from tenant2")

# User2 must not be a member of tenant1
can_access = verify_membership(host_db, user2.id, tenant_id)
assert not can_access, "User2 should NOT have access to tenant1"
print("   ✅ User2 blocked from tenant1")

# A tampered session pointing user1 at tenant2 must be rejected
malicious_session = {
    'user_id': new_user.id,
    'tenant_id': tenant2_id,  # Wrong tenant!
    'email': 'newuser@test_oauth.com',
    'tenant_role': 'owner'
}
try:
    leaked_db = require_tenant_access(malicious_session)
except PermissionError as e:
    assert "Access denied" in str(e), f"Unexpected error message: {e}"
    print("   ✅ Cross-tenant access blocked by require_tenant_access()")
else:
    raise AssertionError("Cross-tenant access should be blocked!")

# Legitimate sessions must still resolve to their own tenant database
session1 = {
    'user_id': new_user.id,
    'tenant_id': tenant_id,
    'email': 'newuser@test_oauth.com',
    'tenant_role': 'owner'
}
tenant_db1 = require_tenant_access(session1)
assert tenant_db1 is not None
print("   ✅ User1 can access tenant1")

session2 = {
    'user_id': user2.id,
    'tenant_id': tenant2_id,
    'email': 'user2@test_oauth.com',
    'tenant_role': 'owner'
}
tenant_db2 = require_tenant_access(session2)
assert tenant_db2 is not None
print("   ✅ User2 can access tenant2")

In [None]:
#| hide

# ==========================================
# TEST 5: System Admin Routing
# ==========================================

print("\n5️⃣ Testing System Admin Routing...")

# Insert a system admin record directly, with is_sys_admin set explicitly
from fh_saas.db_host import GlobalUser, timestamp, gen_id
admin_user = GlobalUser(
    id=gen_id(),
    email='admin@test_oauth.com',
    oauth_id='google_admin_789',
    is_sys_admin=True,
    created_at=timestamp(),
    last_login=timestamp()
)
host_db.global_users.insert(admin_user)
host_db.commit()
print("   ✅ System admin user created")

# Admins go to the admin dashboard even without a tenant membership
redirect_url = route_user_after_login(admin_user, None)
assert redirect_url == '/admin/dashboard', f"Admin routed to {redirect_url!r}"
print("   ✅ Admin routed to /admin/dashboard")

# Re-read user1's membership here instead of relying on the `membership`
# variable left over from an earlier cell
user1_membership = get_user_membership(host_db, new_user.id)
assert user1_membership is not None
redirect_url = route_user_after_login(new_user, user1_membership)
assert redirect_url == '/dashboard', f"Regular user routed to {redirect_url!r}"
print("   ✅ Regular user routed to /dashboard")

In [None]:
#| hide

# ==========================================
# TEST 6: Session Management
# ==========================================

print("\n6️⃣ Testing Session Management...")

# Fetch the membership explicitly rather than reusing one from an earlier cell
session_membership = get_user_membership(host_db, new_user.id)
assert session_membership is not None

# Session creation must populate the keys the app relies on
test_session = {}
create_user_session(test_session, new_user, session_membership)
missing_keys = [key for key in ('user_id', 'tenant_id', 'login_at') if key not in test_session]
assert not missing_keys, f"Session missing keys: {missing_keys}"
assert test_session['user_id'] == new_user.id
print("   ✅ Session created with all required keys")

# get_current_user reads the user back from the session
user_info = get_current_user(test_session)
assert user_info is not None
assert user_info['email'] == 'newuser@test_oauth.com'
print("   ✅ get_current_user() works")

# Logout empties the session, after which no user is resolved
clear_session(test_session)
assert len(test_session) == 0
user_info = get_current_user(test_session)
assert user_info is None
print("   ✅ Session cleared (logout)")

In [None]:
#| hide

# ==========================================
# CLEANUP & SUMMARY
# ==========================================

# Only reached when every test cell above ran without an assertion failure.
separator = "=" * 60
completed_tests = [
    "CSRF state generation and validation",
    "New user auto-provisioning",
    "Returning user login",
    "Cross-tenant access prevention (CRITICAL)",
    "System admin routing",
    "Session management",
]
next_steps = [
    # The functions under test are imported from fh_saas.utils_oauth
    "Run nbdev_export to generate fh_saas/utils_oauth.py",
    "Register OAuth redirect URI in Google Cloud Console",
    "Integrate route handlers into FastHTML app",
    "Test with real Google OAuth flow",
    "Implement user invitation system for multi-user tenants",
]

print("\n" + separator)
print("✅ ALL OAUTH TESTS PASSED!")
print(separator)
print("\nTests Completed:")
for number, test_name in enumerate(completed_tests, start=1):
    print(f"  {number}. {test_name} ✅")
print("\n🔒 Security: Tenant isolation validated")
print("🎯 Ready for production integration")
print("\n💡 Next Steps:")
for number, step in enumerate(next_steps, start=1):
    print(f"  {number}. {step}")

# Remove the rows created by this run
cleanup_oauth_test_data()
print("\n🧹 Test data cleaned up")

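## Using the HostDatabase Singleton with Explicit Transaction Management

The OAuth helpers take a `HostDatabase` instance as an explicit `host_db` argument (dependency injection) instead of relying on application-level objects. Transaction handling is split between caller and helper:

- After `create_or_get_global_user`, the caller commits, or rolls back on error.
- `provision_new_user` commits or rolls back internally.
- Read-only helpers such as `get_user_membership` and `verify_membership` need no commit.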

In [None]:
#| hide

# ==========================================
# EXAMPLE: NEW PATTERN WITH HostDatabase
# ==========================================

print("\n🔧 NEW PATTERN: Using HostDatabase Singleton")
print("=" * 60)

# 1. Initialize HostDatabase singleton (once per application)
from fh_saas.db_host import HostDatabase
from fh_saas.utils_log import DatabaseHandler

host_db = HostDatabase.from_env()
log_handler = DatabaseHandler()  # Optional logging
print("✅ HostDatabase singleton initialized")

# NOTE: unlike the tests above, this example writes to the host DB with a
# non-test email, so cleanup_oauth_test_data() does not remove these rows.

# 2. Create or get user. The caller owns this transaction: commit on success,
#    roll back explicitly on failure. Re-raise so the steps below never run
#    with a stale `new_user` left over from earlier cells.
try:
    new_user = create_or_get_global_user(
        host_db=host_db,
        oauth_id='google_123',
        email='test@example.com',
        log_handler=log_handler
    )
    host_db.commit()  # Caller commits
except Exception as e:
    host_db.rollback()  # Caller rolls back
    print(f"❌ User creation failed: {e}")
    raise
print(f"✅ User created/retrieved: {new_user.email}")

# 3. Get user membership (read-only, no transaction needed)
membership = get_user_membership(
    host_db=host_db,
    user_id=new_user.id,
    log_handler=log_handler
)
print(f"✅ Membership: {membership.tenant_id if membership else 'None'}")

# 4. Provision new tenant (full transaction management inside function)
if membership is None:
    try:
        provisioned_tenant_id = provision_new_user(
            host_db=host_db,
            global_user=new_user,
            log_handler=log_handler
        )
        # No commit needed - provision_new_user commits internally
        print(f"✅ Tenant provisioned: {provisioned_tenant_id}")
    except Exception as e:
        # No rollback needed - provision_new_user rolls back internally
        print(f"❌ Provisioning failed: {e}")
    else:
        # Re-read so step 5 checks the freshly provisioned tenant instead of
        # the stale `None` from step 3 (fresh transaction, as in TEST 2)
        host_db.rollback()
        membership = get_user_membership(
            host_db=host_db,
            user_id=new_user.id,
            log_handler=log_handler
        )

# 5. Verify membership (read-only, security check); 'fake_id' demonstrates a denial
tenant_to_check = membership.tenant_id if membership else 'fake_id'
has_access = verify_membership(
    host_db=host_db,
    user_id=new_user.id,
    tenant_id=tenant_to_check,
    log_handler=log_handler
)
access_icon = "✅" if has_access else "❌"
print(f"{access_icon} Access verification: {has_access}")

# 6. Check logs (skipped when logging is disabled by setting log_handler = None)
if log_handler is not None:
    logs = log_handler.get_logs()
    print(f"✅ Logged {len(logs)} operations")
    for entry in logs[-3:]:  # Show last 3 logs
        print(f"   [{entry['level']}] {entry['message']}")

print("\n💡 Key Benefits:")
print("  - Singleton pattern ensures single host DB connection")
print("  - Explicit transaction management (commit/rollback)")
print("  - Integrated logging for audit trail")
print("  - Read-only operations skip transactions for performance")
print("  - Module is independent of application-level objects")