Skip to content
Merged
69 changes: 68 additions & 1 deletion mcpgateway/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from datetime import datetime, timezone
import hashlib
import logging
from typing import Generator, Never, Optional
from typing import Any, Dict, Generator, Never, Optional
import uuid

# Third-Party
Expand All @@ -26,6 +26,7 @@
from mcpgateway.config import settings
from mcpgateway.db import EmailUser, SessionLocal
from mcpgateway.plugins.framework import get_plugin_manager, GlobalContext, HttpAuthResolveUserPayload, HttpHeaderPayload, HttpHookType, PluginViolationError
from mcpgateway.services.team_management_service import TeamManagementService # pylint: disable=import-outside-toplevel
from mcpgateway.utils.verify_credentials import verify_jwt_token

# Security scheme
Expand Down Expand Up @@ -53,6 +54,67 @@ def get_db() -> Generator[Session, Never, None]:
db.close()


async def get_team_from_token(payload: Dict[str, Any], db: Session) -> Optional[str]:
"""
Extract the team ID from an authentication token payload. If the token does
not include a team, the user's personal team is retrieved from the database.

This function behaves as follows:

1. If `payload["teams"]` exists and is non-empty:
Returns the first team ID from that list.

2. If no teams are present in the payload:
Fetches the user's teams (using `payload["sub"]` as the user email) and
returns the ID of the personal team, if one exists.

Args:
payload (Dict[str, Any]):
The token payload. Expected fields:
- "sub" (str): The user's unique identifier (email).
- "teams" (List[str], optional): List containing team ID.
db (Session):
SQLAlchemy database session used to query team data.

Returns:
Optional[str]:
The resolved team ID. Returns `None` if no team can be determined
either from the payload or from the database.

Examples:
>>> import sys, asyncio
>>> from unittest.mock import AsyncMock, MagicMock
>>>
>>> # --- Mock setup for both tests ---
>>> mock_db = MagicMock()
>>>
>>> # Patch TeamManagementService import path dynamically
>>> mock_team_service = AsyncMock()
>>> mock_team = MagicMock(id="personal_team_123", is_personal=True)
>>> mock_team_service.get_user_teams.return_value = [mock_team]
>>>
>>> sys.modules['mcpgateway.services.team_management_service'] = type(sys)("dummy")
>>> sys.modules['mcpgateway.services.team_management_service'].TeamManagementService = lambda db: mock_team_service
>>>
>>> # --- Case 1: Token has team ---
>>> payload = {"sub": "user@example.com", "teams": ["team_456"]}
>>> asyncio.run(get_team_from_token(payload, mock_db))
'team_456'
>>> del sys.modules["mcpgateway.services.team_management_service"]
"""
team_id = payload.get("teams")[0] if payload.get("teams") else None
user_email = payload.get("sub")
# If no team found in token, get user's personal team
if not team_id:

team_service = TeamManagementService(db)
user_teams = await team_service.get_user_teams(user_email, include_personal=True)
personal_team = next((team for team in user_teams if team.is_personal), None)
team_id = personal_team.id if personal_team else None

return team_id


async def get_current_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme),
db: Session = Depends(get_db),
Expand Down Expand Up @@ -227,6 +289,11 @@ async def get_current_user(
# Log the error but don't fail authentication for admin tokens
logger.warning(f"Token revocation check failed for JTI {jti}: {revoke_check_error}")

# Check team level token, if applicable. If public token, then will be defaulted to personal team.
team_id = await get_team_from_token(payload, db)
if request and hasattr(request, "state"):
request.state.team_id = team_id

except HTTPException:
# Re-raise HTTPException from verify_jwt_token (handles expired/invalid tokens)
raise
Expand Down
Loading
Loading