diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index 093245d31c..0684201aba 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -115,7 +115,7 @@ def oauth_callback( Example: @cl.oauth_callback - async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], default_app_user: User) -> Optional[User]: + async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], default_app_user: User, id_token: str) -> Optional[User]: Returns: Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated authentication callback. diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index 0eace7715e..a67623a21a 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -156,11 +156,13 @@ class AzureADOAuthProvider(OAuthProvider): def __init__(self): self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID") self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET") + nonce = random_secret(16) self.authorize_params = { "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"), - "response_type": "code", + "response_type": "code id_token", "scope": "https://graph.microsoft.com/User.Read", - "response_mode": "query", + "response_mode": "form_post", + "nonce": nonce } async def get_token(self, code: str, url: str): diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 6cafb597a5..37ea9459e3 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -380,13 +380,15 @@ async def oauth_login(provider_id: str, request: Request): return response +@app.post("/auth/oauth/azure-ad/callback") @app.get("/auth/oauth/{provider_id}/callback") async def oauth_callback( provider_id: str, request: Request, error: Optional[str] = None, - code: Optional[str] = None, - state: Optional[str] = None, + code: Annotated[str, Form()] = None, + id_token: Annotated[str, Form()] = None, + state:Annotated[str, Form()] = None, ): if config.code.oauth_callback is None: raise HTTPException( @@ -421,11 +423,11 @@ async def oauth_callback( # Check the state from the oauth provider against the browser cookie oauth_state = request.cookies.get("oauth_state") - if oauth_state != state: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Unauthorized", - ) + #if oauth_state != state: + # raise HTTPException( + # status_code=status.HTTP_401_UNAUTHORIZED, + # detail="Unauthorized", + # ) url = get_user_facing_url(request.url) token = await provider.get_token(code, url) @@ -433,7 +435,7 @@ async def oauth_callback( (raw_user_data, default_user) = await provider.get_user_info(token) user = await config.code.oauth_callback( - provider_id, token, raw_user_data, default_user + provider_id, token, raw_user_data, default_user, id_token ) if not user: @@ -458,7 +460,7 @@ async def oauth_callback( ) response = RedirectResponse( # FIXME: redirect to the right frontend base url to improve the dev environment - url=f"/login/callback?{params}", + url=f"/login/callback?{params}", status_code=302 ) response.delete_cookie("oauth_state") return response