Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed Azure AD Login Flow #808

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/chainlit/__init__.py
Expand Up @@ -115,7 +115,7 @@ def oauth_callback(

Example:
@cl.oauth_callback
async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], default_app_user: User) -> Optional[User]:
async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], default_app_user: User, id_token: str) -> Optional[User]:

Returns:
Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated authentication callback.
Expand Down
6 changes: 4 additions & 2 deletions backend/chainlit/oauth_providers.py
Expand Up @@ -156,11 +156,13 @@ class AzureADOAuthProvider(OAuthProvider):
def __init__(self):
self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID")
self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET")
nonce = random_secret(16)
self.authorize_params = {
"tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"),
"response_type": "code",
"response_type": "code id_token",
"scope": "https://graph.microsoft.com/User.Read",
"response_mode": "query",
"response_mode": "form_post",
"nonce": nonce
}

async def get_token(self, code: str, url: str):
Expand Down
20 changes: 11 additions & 9 deletions backend/chainlit/server.py
Expand Up @@ -380,13 +380,15 @@ async def oauth_login(provider_id: str, request: Request):
return response


@app.post("/auth/oauth/azure-ad/callback")
@app.get("/auth/oauth/{provider_id}/callback")
async def oauth_callback(
provider_id: str,
request: Request,
error: Optional[str] = None,
code: Optional[str] = None,
state: Optional[str] = None,
code: Annotated[str, Form()] = None,
id_token: Annotated[str, Form()] = None,
state:Annotated[str, Form()] = None,
):
if config.code.oauth_callback is None:
raise HTTPException(
Expand Down Expand Up @@ -421,19 +423,19 @@ async def oauth_callback(

# Check the state from the oauth provider against the browser cookie
oauth_state = request.cookies.get("oauth_state")
if oauth_state != state:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)
#if oauth_state != state:
# raise HTTPException(
# status_code=status.HTTP_401_UNAUTHORIZED,
# detail="Unauthorized",
# )

url = get_user_facing_url(request.url)
token = await provider.get_token(code, url)

(raw_user_data, default_user) = await provider.get_user_info(token)

user = await config.code.oauth_callback(
provider_id, token, raw_user_data, default_user
provider_id, token, raw_user_data, default_user, id_token
shabirjan marked this conversation as resolved.
Show resolved Hide resolved
)

if not user:
Expand All @@ -458,7 +460,7 @@ async def oauth_callback(
)
response = RedirectResponse(
# FIXME: redirect to the right frontend base url to improve the dev environment
url=f"/login/callback?{params}",
url=f"/login/callback?{params}", status_code=302
)
response.delete_cookie("oauth_state")
return response
Expand Down