In [None]:
# | default_exp sso

In [None]:
# | export

import os
import secrets
from datetime import datetime, timedelta
from typing import *

import requests
from fastapi import HTTPException, Request, status
from pydantic import BaseModel
from requests_oauthlib import OAuth2Session
from sqlalchemy.exc import NoResultFound
from sqlmodel import select

import airt_service.sanitizer
from airt_service.db.models import (
    SSO,
    SSOProtocol,
    SSOProvider,
    User,
    get_session_with_context,
)
from airt_service.errors import ERRORS
from airt_service.helpers import commit_or_rollback

In [None]:
import secrets
from contextlib import contextmanager
from urllib.parse import parse_qs, urlparse

import pytest
from sqlmodel import select

from airt_service.db.models import (
    create_user_for_testing,
    get_session,
    get_session_with_context,
)
from airt_service.users import EnableSSORequest, disable_sso, enable_sso

23-03-10 10:25:12.348 [INFO] airt.executor.subcommand: Module loaded.


In [None]:
# | exporti

# Google's OpenID Connect discovery document (its OAuth2 endpoints are read from here)
GOOGLE_DISCOVERY_URL = "https://accounts.google.com/.well-known/openid-configuration"

# Text rendered in the browser once the SSO round trip succeeds
SSO_SUCCESS_MSG = "Authentication successful. Please close the browser."

# Maximum age, in minutes, of an SSO login attempt before it is rejected as expired
SESSION_TIME_LIMIT = 10

# Per-provider OAuth2 settings. Google only needs its scopes here because its
# endpoints come from GOOGLE_DISCOVERY_URL; GitHub's endpoints are static.
SSO_CONFIG: Dict[str, Any] = dict(
    google=dict(
        scope=[
            "https://www.googleapis.com/auth/userinfo.email",
            "https://www.googleapis.com/auth/userinfo.profile",
            "openid",
        ],
    ),
    github=dict(
        scope="user:email",
        authorization_endpoint="https://github.com/login/oauth/authorize",
        token_endpoint="https://github.com/login/oauth/access_token",
        userinfo_endpoint="https://api.github.com/user/emails",
    ),
)

In [None]:
# | exporti


def _generate_callback_url() -> str:
    """Generate callback URL for the SSO provider

    This is the URL the user will be redirected to after successful user authentication.

    Returns:
        The generated callback URL
    """
    domain_in_env = os.environ["DOMAIN"]
    domain = (
        "http://127.0.0.1:6006"
        if (domain_in_env == "localhost" or "airt-service" in domain_in_env)
        else f"https://{domain_in_env}"
    )

    callback_url = f"{domain}/sso/callback"

    return callback_url

In [None]:
# Assumes the test environment's DOMAIN is "localhost" or contains "airt-service",
# which maps to the fixed local callback address.
actual = _generate_callback_url()

display(actual)
assert "http://127.0.0.1:6006/sso/callback" == actual

'http://127.0.0.1:6006/sso/callback'

In [None]:
@contextmanager
def non_local_domain(domain: str = "not_localhost"):
    """Temporarily set the ``DOMAIN`` environment variable to a non-local host.

    The previous value is restored on exit, even if the body raises. If ``DOMAIN``
    was not set beforehand, it is removed again on exit. (Previously an unset
    ``DOMAIN`` caused a KeyError and then an unbound-variable error in ``finally``.)

    Args:
        domain: Value exposed as ``DOMAIN`` inside the ``with`` block.
    """
    previous_domain = os.environ.get("DOMAIN")
    os.environ["DOMAIN"] = domain
    try:
        yield
    finally:
        if previous_domain is None:
            os.environ.pop("DOMAIN", None)
        else:
            os.environ["DOMAIN"] = previous_domain

In [None]:
# Any other DOMAIN value produces an HTTPS callback on that domain
with non_local_domain():
    actual = _generate_callback_url()

    display(actual)
    assert actual == "https://not_localhost/sso/callback"

'https://not_localhost/sso/callback'

In [None]:
# | exporti


def _get_google_provider_cfg(
    api_uri: Optional[str] = None, timeout: float = 10.0
) -> Union[str, dict]:
    """Get google's OpenID Connect configuration

    This configuration includes the URIs of the authorization, token, revocation, userinfo, and public-keys endpoints.

    Args:
        api_uri: Key of a single endpoint to return from the configuration
            (e.g. ``"token_endpoint"``). If **None** (the default), the whole
            configuration dictionary is returned.
        timeout: Seconds to wait for Google's discovery endpoint before giving up.

    Returns:
        The requested endpoint URI, or the full configuration dictionary

    Raises:
        HTTPException: 503 if the discovery document cannot be fetched, returns an
            HTTP error status, or is not valid JSON.
        KeyError: If ``api_uri`` is not a key of the configuration.
    """
    try:
        # Without a timeout, requests may wait indefinitely on a stalled connection
        response = requests.get(GOOGLE_DISCOVERY_URL, timeout=timeout)
        # Treat 4xx/5xx answers as "unavailable" instead of parsing an error body
        response.raise_for_status()
        google_provider_cfg = response.json()
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=ERRORS["SERVICE_UNAVAILABLE"],
        ) from e

    return google_provider_cfg[api_uri] if api_uri is not None else google_provider_cfg  # type: ignore

In [None]:
# Fetches the live discovery document, so this cell needs network access.
# Requesting a single key must return the same value as indexing the full config.
full_cfg = _get_google_provider_cfg()
assert "token_endpoint" in full_cfg.keys()
assert "authorization_endpoint" in full_cfg.keys()

api_uri = "authorization_endpoint"
actual = _get_google_provider_cfg(api_uri)
display(actual)
assert full_cfg[api_uri] == actual

'https://accounts.google.com/o/oauth2/v2/auth'

In [None]:
# | exporti


def _get_authorization_url_and_nonce(
    sso_provider: str, username: str, nonce: str
) -> Tuple[str, str]:
    """Build the provider authorization URL for a new SSO login attempt.

    The OAuth2 ``state`` is set to ``"<nonce>_<username>"`` so that both values can
    be recovered from the provider's callback.

    Args:
        sso_provider: The name of the SSO provider (a key of ``SSO_CONFIG``)
        username: Username appended to the nonce in the ``state``
        nonce: Cryptographically strong random string

    Returns:
        The authorization URL and the ``state`` value embedded in it
    """
    redirect_to = _generate_callback_url()
    oauth_session = OAuth2Session(
        os.environ[f"{sso_provider.upper()}_CLIENT_ID"],
        scope=SSO_CONFIG[sso_provider]["scope"],
        redirect_uri=redirect_to,
        state=f"{nonce}_{username}",
    )

    if sso_provider == "google":
        # Google's endpoint is read from its OpenID discovery document
        endpoint = _get_google_provider_cfg(api_uri="authorization_endpoint")
    else:
        endpoint = SSO_CONFIG[sso_provider]["authorization_endpoint"]

    url, state = oauth_session.authorization_url(endpoint, prompt="select_account")
    return url, state

In [None]:
# Google: the URL must start with the endpoint from the live discovery document,
# carry "<nonce>_<username>" as its state and point back at our callback URL.
sso_provider = "google"
username = "random_username"
nonce = secrets.token_hex()
authorization_url, nonce_with_username = _get_authorization_url_and_nonce(
    sso_provider, username, nonce
)

parse_result = urlparse(authorization_url)
dict_result = parse_qs(parse_result.query)

display(authorization_url)

api_uri = "authorization_endpoint"
expected_authorization_base_url = _get_google_provider_cfg(api_uri)

assert expected_authorization_base_url in authorization_url
assert dict_result["state"][0] == f"{nonce}_{username}"
assert dict_result["redirect_uri"][0] == _generate_callback_url()

'https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=842138153914-6kvm51cpin7iocg3nrsnl44s3d24u047.apps.googleusercontent.com&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.profile+openid&state=0fd3a4312ecbf54afc4373f56d692010a59993fd2e1dc6454ab0cf4702cf1f4d_random_username&prompt=select_account'

In [None]:
# GitHub: same checks, but the base URL is the static endpoint from SSO_CONFIG
sso_provider = "github"
username = "random_username"
nonce = secrets.token_hex()
authorization_url, nonce_with_username = _get_authorization_url_and_nonce(
    sso_provider, username, nonce
)

parse_result = urlparse(authorization_url)
dict_result = parse_qs(parse_result.query)

display(authorization_url)

assert SSO_CONFIG[f"{sso_provider}"]["authorization_endpoint"] in authorization_url
assert dict_result["state"][0] == f"{nonce}_{username}"
assert dict_result["redirect_uri"][0] == _generate_callback_url()

'https://github.com/login/oauth/authorize?response_type=code&client_id=a0f58d9e50375190dbf0&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=user%3Aemail&state=6a2d6f2c383776c00733f09d21768ca108dcc6571c8e7ff80e0ec1ca2d8379a4_random_username&prompt=select_account'

In [None]:
# | export


def get_valid_sso_providers() -> List[str]:
    """Get valid SSO providers

    Returns:
        The ``value`` of every ``SSOProvider`` member, in declaration order
    """
    providers: List[str] = []
    for provider in SSOProvider:
        providers.append(provider.value)
    return providers

In [None]:
# The supported providers, in SSOProvider declaration order
actual = get_valid_sso_providers()
assert actual == ["google", "github"]
actual

['google', 'github']

In [None]:
# | exporti


def get_sso_if_enabled_for_user(user_id: int, sso_provider: str) -> Optional[SSO]:
    """Return the user's SSO record for a provider, but only while it is active.

    Args:
        user_id: Id of the user whose SSO record is looked up
        sso_provider: The name of the SSO provider

    Returns:
        The matching SSO record if it exists and is not disabled, otherwise None
    """
    with get_session_with_context() as session:
        query = (
            select(SSO)
            .where(SSO.user_id == user_id)
            .where(SSO.sso_provider == sso_provider)
        )
        try:
            record = session.exec(query).one()
        except NoResultFound:
            return None
        return None if record.disabled else record  # type: ignore

In [None]:
# A freshly created test user has no SSO record, so the lookup must return None
with get_session_with_context() as session:
    user = session.exec(
        select(User).where(User.username == create_user_for_testing())
    ).one()
    sso_provider = "google"
    actual = get_sso_if_enabled_for_user(user.id, sso_provider)
    display(actual)
    assert not actual

None

In [None]:
# Test helper: yields a freshly created test user with SSO enabled for the given
# provider (Google by default) and deactivates that user on exit.


@contextmanager
def create_sso_user(
    sso_provider: str = "google", sso_email: str = "random_email_id@mail.com"
):
    """Create a test user with SSO enabled and yield ``(user, session)``.

    Args:
        sso_provider: Provider to enable SSO for
        sso_email: Email address registered for that provider

    Yields:
        The created user and the session it was loaded in
    """
    with get_session_with_context() as session:
        new_username = create_user_for_testing()
        user = session.exec(select(User).where(User.username == new_username)).one()
        try:
            sso_request = EnableSSORequest(sso_provider=sso_provider, sso_email=sso_email)
            display(enable_sso(enable_sso_request=sso_request, user=user, session=session))
            yield user, session
        finally:
            # Always deactivate the test user, even if the body raised
            with commit_or_rollback(session):
                user.disabled = True
                session.add(user)
                session.commit()
            assert user.disabled

In [None]:
with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    sso_provider = "google"
    actual = get_sso_if_enabled_for_user(user.id, sso_provider)
    # These checks used to be bare comparisons (no `assert`), so they never ran
    assert actual is not None
    assert actual.sso_provider == "google"
    assert actual.sso_email == "random_email_id@mail.com"

SSO()

In [None]:
# After disable_sso, the lookup helper must treat the record as absent
with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    sso_provider = "google"
    # disable SSO
    actual = disable_sso(
        user_uuid_or_name=str(user.uuid),
        sso_provider=sso_provider,
        user=user,
        session=session,
    )
    display(actual)
    assert actual.sso_provider == sso_provider
    assert actual.user_id == user.id
    # Now access the SSO record from db
    actual = get_sso_if_enabled_for_user(user.id, sso_provider)
    display(actual)
    assert not actual

SSO()

SSO()

None

In [None]:
# | export


class SSOAuthURL(BaseModel):
    """Response model carrying the provider's authorization URL.

    Attributes:
        authorization_url: URL the user must open to authenticate with the SSO provider
    """

    authorization_url: str

In [None]:
# | export


def initiate_sso_flow(
    username: str, sso_provider: str, nonce: str, sso: SSO
) -> SSOAuthURL:
    """Initiate SSO flow and return provider authorization URL

    The nonce is stored in the ``sso_protocol`` table together with its creation
    time, so that the provider's callback can later be matched against it.

    Args:
        username: Username as a string; embedded with the nonce in the OAuth2 ``state``
        sso_provider: The name of the SSO provider
        nonce: A cryptographically strong random string
        sso: The user's SSO record; it is merged into a new session before being linked

    Returns:
        The authorization URL for the SSO provider
    """
    # Step 1: Generate authorization_url with username and nonce added to the state
    authorization_url, _ = _get_authorization_url_and_nonce(
        sso_provider, username, nonce
    )

    # Step 2: store the nonce and created_at in the sso_protocol table.
    # created_at is naive UTC, matching the datetime.utcnow() comparison used when
    # the callback is validated.
    with get_session_with_context() as session:
        sso_protocol = SSOProtocol(nonce=nonce, created_at=datetime.utcnow())
        # `sso` may have been loaded in another (closed) session; merge attaches it here
        sso_protocol.sso = session.merge(sso)
        session.add(sso_protocol)
        session.commit()

    # Step 3: return the redirect URL to the user
    return SSOAuthURL(authorization_url=authorization_url)

In [None]:
with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    sso_provider = "google"

    # Initiating the flow repeatedly must embed each call's own nonce in the state
    for _ in range(3):
        sso = get_sso_if_enabled_for_user(user.id, sso_provider)
        nonce = secrets.token_hex()
        actual = initiate_sso_flow(user.username, sso_provider, sso=sso, nonce=nonce)
        display(actual)
        assert f"{nonce}_{user.username}" in actual.authorization_url

    # check that the most recent nonce was stored in the sso_protocol table
    sso = get_sso_if_enabled_for_user(user.id, sso_provider)
    with get_session_with_context() as session:
        sso_protocol = session.exec(
            select(SSOProtocol)
            .where(SSOProtocol.sso_id == sso.id)
            .where(SSOProtocol.nonce == nonce)
        ).one()
    display(sso_protocol)

SSO()

SSOAuthURL(authorization_url='https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=842138153914-6kvm51cpin7iocg3nrsnl44s3d24u047.apps.googleusercontent.com&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.profile+openid&state=a6c16476592868111d68f0885e54f66b2934a063d63a620e58c732ce25b3f0ca_vkylrpnbch&prompt=select_account')

SSOAuthURL(authorization_url='https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=842138153914-6kvm51cpin7iocg3nrsnl44s3d24u047.apps.googleusercontent.com&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.profile+openid&state=340dc3ad383f0be8c73a14fdce752d69160a43b1f88630391a132a8a85e8d785_vkylrpnbch&prompt=select_account')

SSOAuthURL(authorization_url='https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=842138153914-6kvm51cpin7iocg3nrsnl44s3d24u047.apps.googleusercontent.com&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.profile+openid&state=7825f9203f580fc8a565d5d9b2c77951826e1b0ddeb71909c3241c660c166e3c_vkylrpnbch&prompt=select_account')

SSOProtocol(id=103, created_at=datetime.datetime(2023, 3, 10, 10, 25, 15), error=None, is_sso_successful=False, nonce='7825f9203f580fc8a565d5d9b2c77951826e1b0ddeb71909c3241c660c166e3c', sso_id=380)

In [None]:
sso_provider = "github"
sso_email = "random.mail@gmail.com"
with create_sso_user(
    sso_provider=sso_provider, sso_email=sso_email
) as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    # GitHub: the authorization URL must carry "<nonce>_<username>" as its state
    sso = get_sso_if_enabled_for_user(user.id, sso_provider)
    nonce = secrets.token_hex()
    actual = initiate_sso_flow(user.username, sso_provider, sso=sso, nonce=nonce)
    display(actual)
    assert f"{nonce}_{user.username}" in actual.authorization_url

SSO()

SSOAuthURL(authorization_url='https://github.com/login/oauth/authorize?response_type=code&client_id=a0f58d9e50375190dbf0&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=user%3Aemail&state=284d99c8b1eb5a08031b2ac4086fe5c955f5c3d0448afb60adf132b947c685da_ehxgpbhhfv&prompt=select_account')

In [None]:
# GitHub: besides the state, the URL must use the static GitHub endpoint and the
# nonce must be persisted in the sso_protocol table.
sso_provider = "github"
sso_email = "random.mail@gmail.com"
with create_sso_user(
    sso_provider=sso_provider, sso_email=sso_email
) as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    sso = get_sso_if_enabled_for_user(user.id, sso_provider)
    nonce = secrets.token_hex()
    actual = initiate_sso_flow(user.username, sso_provider, sso=sso, nonce=nonce)
    display(actual)
    assert f"{nonce}_{user.username}" in actual.authorization_url
    assert (
        SSO_CONFIG[f"{sso_provider}"]["authorization_endpoint"]
        in actual.authorization_url
    )

    # check if the record exists in DB
    sso = get_sso_if_enabled_for_user(user.id, sso_provider)
    with get_session_with_context() as session:
        sso_protocol = session.exec(
            select(SSOProtocol)
            .where(SSOProtocol.sso_id == sso.id)
            .where(SSOProtocol.nonce == nonce)
        ).one()
    display(sso)
    display(sso_protocol)

SSO()

SSOAuthURL(authorization_url='https://github.com/login/oauth/authorize?response_type=code&client_id=a0f58d9e50375190dbf0&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=user%3Aemail&state=c6d073e6abf26b565ccbe650db1f84898fd5482eb61c372eb2ebdd61e7ab08b1_vkaegidcoy&prompt=select_account')

SSO(disabled=False, sso_email='random.mail@gmail.com', user_id=1057, sso_provider=<SSOProvider.github: 'github'>, id=382)

SSOProtocol(id=105, created_at=datetime.datetime(2023, 3, 10, 10, 25, 16), error=None, is_sso_successful=False, nonce='c6d073e6abf26b565ccbe650db1f84898fd5482eb61c372eb2ebdd61e7ab08b1', sso_id=382)

In [None]:
# | exporti


def _get_token_and_user_url(sso_provider: str) -> Tuple[str, str]:
    """Get token and the user info URL endpoints for the given provider

    Args:
        sso_provider: Name of the SSO provider

    Returns:
        The token and the user info URL endpoints for the given provider

    Raises:
        HTTPException: 503 (from ``_get_google_provider_cfg``) if Google's discovery
            document cannot be fetched
        KeyError: If ``sso_provider`` is not configured in ``SSO_CONFIG``
    """
    if sso_provider == "google":
        # Fetch Google's discovery document once and read both endpoints from it,
        # instead of issuing one HTTP request per endpoint.
        google_cfg = _get_google_provider_cfg()
        return google_cfg["token_endpoint"], google_cfg["userinfo_endpoint"]  # type: ignore

    provider_cfg = SSO_CONFIG[sso_provider]
    return provider_cfg["token_endpoint"], provider_cfg["userinfo_endpoint"]

In [None]:
# Google endpoints come from the live discovery document (network access required)
sso_provider = "google"
token_endpoint, userinfo_endpoint = _get_token_and_user_url(sso_provider)
display(token_endpoint)
display(userinfo_endpoint)

assert "/token" in token_endpoint
assert "/userinfo" in userinfo_endpoint

'https://oauth2.googleapis.com/token'

'https://openidconnect.googleapis.com/v1/userinfo'

In [None]:
# GitHub endpoints are taken verbatim from SSO_CONFIG
sso_provider = "github"
token_endpoint, userinfo_endpoint = _get_token_and_user_url(sso_provider)
display(token_endpoint)
display(userinfo_endpoint)

assert SSO_CONFIG[f"{sso_provider}"]["token_endpoint"] == token_endpoint
assert SSO_CONFIG[f"{sso_provider}"]["userinfo_endpoint"] == userinfo_endpoint

'https://github.com/login/oauth/access_token'

'https://api.github.com/user/emails'

In [None]:
# | exporti


def get_user_info_from_provider(
    url: str, nonce_with_username: str, sso_provider: str
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Get user info from the provider

    This function exchanges the authorization code in the callback URL for an access
    token, then uses that token to fetch the user's details from the provider.

    Args:
        url: The full callback URL received from the SSO provider
        nonce_with_username: The ``state`` value created by the client
            (``"<nonce>_<username>"``); it is checked against the ``state`` in ``url``
        sso_provider: Name of the SSO provider

    Returns:
        The decoded JSON body of the provider's userinfo endpoint (a dict for Google;
        a list for GitHub's ``/user/emails`` endpoint)

    Raises:
        HTTPException: 400 if the authorization code cannot be exchanged for a token
        KeyError: If the provider's ``*_CLIENT_ID``/``*_CLIENT_SECRET`` environment
            variables are not set
    """
    redirect_uri = _generate_callback_url()
    client_id = os.environ[f"{sso_provider.upper()}_CLIENT_ID"]
    # Read the secret before the token exchange. A missing secret then surfaces as a
    # configuration error, instead of being reported to the user as a CSRF warning.
    client_secret = os.environ[f"{sso_provider.upper()}_CLIENT_SECRET"]

    client = OAuth2Session(
        client_id,
        state=nonce_with_username,
        redirect_uri=redirect_uri,
    )

    token_endpoint, userinfo_endpoint = _get_token_and_user_url(sso_provider)
    try:
        # The fetched token is stored on `client` and used by the `client.get` below
        client.fetch_token(
            token_endpoint,
            client_secret=client_secret,
            authorization_response=url,
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["SSO_CSRF_WARNING"],
        ) from e

    response: Union[Dict[str, Any], List[Dict[str, Any]]] = client.get(
        userinfo_endpoint
    ).json()
    return response

In [None]:
# A fabricated callback (bogus code, relative URL): the token exchange must fail
# and be reported to the caller as the SSO_CSRF_WARNING message.
url = "/sso/google/callback?state=a1b2c3_random_username&code=123456&scope=email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+openid&authuser=0&prompt=none"
sso_provider = "google"
nonce_with_username = "a1b2c3_random_username"
with pytest.raises(HTTPException) as e:
    get_user_info_from_provider(url, nonce_with_username, sso_provider)

err = str(e.value.detail)
assert "Request check failed" in err, err
err

'Request check failed: State not equal in request and response. For your protection, access to this resource is secured against CSRF. Please re-generate the authentication URL and initiate the SSO login process again.'

In [None]:
# | export


def get_sso_protocol_and_email(
    username: str, nonce: str, sso_provider: str
) -> Tuple[SSOProtocol, str]:
    """Get SSO protocol and SSO email details

    Looks up the user, the user's enabled SSO record for ``sso_provider``, and the
    ``sso_protocol`` row that was stored for ``nonce`` when the SSO flow started.

    Args:
        username: Username recovered from the ``state`` returned by the provider
        nonce: Nonce recovered from the ``state`` returned by the provider
        sso_provider: Name of the SSO provider

    Returns:
        The record from sso protocol table and the email address used to enable the sso provider

    Raises:
        HTTPException: 400 if the username is incorrect
        HTTPException: 400 if SSO is not enabled for the provider for this user
        HTTPException: 401 if no sso_protocol record matches the nonce (possible CSRF)
        HTTPException: 401 carrying the stored error if the sso_protocol record has
            already errored out
    """
    # Step 1: Check if username exists
    with get_session_with_context() as session:
        try:
            user = session.exec(select(User).where(User.username == username)).one()
        except NoResultFound:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERRORS["INCORRECT_USERNAME"],
            )
    # Step 2: Check if the sso provider is enabled for the user
    sso = get_sso_if_enabled_for_user(user.id, sso_provider)
    if sso is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["SSO_NOT_ENABLED_FOR_SERVICE"],
        )

    # Step 3: Check if sso_protocol table has a record for this nonce
    with get_session_with_context() as session:
        try:
            sso_protocol = session.exec(
                select(SSOProtocol)
                .where(SSOProtocol.sso_id == sso.id)
                .where(SSOProtocol.nonce == nonce)
            ).one()
        except NoResultFound:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERRORS["SSO_CSRF_WARNING"],
            )

    # Step 4: Check if the sso_protocol already has errored out
    if sso_protocol.error is not None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=sso_protocol.error,
        )

    return sso_protocol, sso.sso_email

In [None]:
# An unknown username is rejected before any SSO lookup happens
sso_provider = "google"
invalid_nonce = "invalid_nonce"
invalid_username = "invalid_username"
with pytest.raises(HTTPException) as e:
    get_sso_protocol_and_email(
        username=invalid_username, nonce=invalid_nonce, sso_provider=sso_provider
    )
assert "Incorrect username" in e.value.detail
e.value.detail

'Incorrect username. Please try again.'

In [None]:
# The helper user has SSO enabled only for Google, so asking for GitHub must fail
with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    sso_provider = "github"
    invalid_nonce = "invalid_nonce"
    with pytest.raises(HTTPException) as e:
        get_sso_protocol_and_email(
            username=user.username, nonce=invalid_nonce, sso_provider=sso_provider
        )
assert "SSO is not enabled for the provider." in e.value.detail
e.value.detail

SSO()

'SSO is not enabled for the provider.'

In [None]:
# SSO is enabled, but no sso_protocol row exists for this nonce -> CSRF warning
with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    sso_provider = "google"
    invalid_nonce = "invalid_nonce"
    with pytest.raises(HTTPException) as e:
        get_sso_protocol_and_email(
            username=user.username, nonce=invalid_nonce, sso_provider=sso_provider
        )
assert "Request check failed:" in e.value.detail
e.value.detail

SSO()

'Request check failed: State not equal in request and response. For your protection, access to this resource is secured against CSRF. Please re-generate the authentication URL and initiate the SSO login process again.'

In [None]:
# | export

def update_user_id_in_sso_table(trial_sso_username: str, user_id_to_update: int) -> None:
    """Move the SSO record of a trial SSO user to another user.

    Args:
        trial_sso_username: Username of the trial SSO user whose SSO record is moved
        user_id_to_update: Id of the user that owns the SSO record afterwards

    Raises:
        HTTPException: If no user is named ``trial_sso_username``, or that user has
            no record in the SSO table
    """
    with get_session_with_context() as session:
        try:
            trial_user_query = select(User).where(User.username == trial_sso_username)
            trial_user = session.exec(trial_user_query).one()

            sso_query = select(SSO).where(SSO.user_id == trial_user.id)
            sso_record = session.exec(sso_query).one()
        except NoResultFound:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERRORS["SSO_GENERIC_ERROR"],
            )

        sso_record.user_id = user_id_to_update
        with commit_or_rollback(session):
            session.add(sso_record)


In [None]:
# Create an SSO-enabled trial user (its SSO row must exist), then move that SSO row
# to a brand-new user without SSO and check the new owner in a fresh session.
with create_sso_user() as user_and_session:
    sso_enabled_user = user_and_session[0]
    session = user_and_session[1]
    session.exec(select(SSO).where(SSO.user_id == sso_enabled_user.id)).one()

with get_session_with_context() as session:
    test_username = create_user_for_testing()
    test_user = session.exec(select(User).where(User.username == test_username)).one()
    assert not session.exec(select(SSO).where(SSO.user_id == test_user.id)).first()
    
    print(f"{test_user.id=}")
    update_user_id_in_sso_table(sso_enabled_user.username, test_user.id)

with get_session_with_context() as session:
    sso = session.exec(select(SSO).where(SSO.user_id == test_user.id)).one()
    print(f"{sso.user_id=}")

SSO()

test_user.id=1064
sso.user_id=1064


In [None]:
# An unknown trial username must surface as the generic SSO error (HTTP 400)
with pytest.raises(HTTPException) as e:
    with get_session_with_context() as session:
        test_username = create_user_for_testing()
        test_user = session.exec(select(User).where(User.username == test_username)).one()
        assert not session.exec(select(SSO).where(SSO.user_id == test_user.id)).first()

        invalid_username = "invalid_username"
        update_user_id_in_sso_table(invalid_username, test_user.id)

print(e.value.detail)

Something went wrong. Please re-generate the authentication URL and initiate the SSO login process again.


In [None]:
# | export


def disable_trial_user(username: str) -> None:
    """Mark the user with the given username as disabled.

    Args:
        username: Username of the (trial) user to disable

    Raises:
        NoResultFound: If no user with ``username`` exists
    """
    with get_session_with_context() as session:
        lookup = select(User).where(User.username == username)
        trial_user = session.exec(lookup).one()
        trial_user.disabled = True
        with commit_or_rollback(session):
            session.add(trial_user)

In [None]:
# A new test user starts enabled; after disable_trial_user a fresh session sees it disabled
with get_session_with_context() as session:
    test_username = create_user_for_testing()
    test_user = session.exec(select(User).where(User.username == test_username)).one()
    print(f"{test_user.disabled=}")
    assert not test_user.disabled
    
disable_trial_user(test_username)

with get_session_with_context() as session:
    test_user = session.exec(select(User).where(User.username == test_username)).one()
    print(f"{test_user.disabled=}")
    assert test_user.disabled

test_user.disabled=False
test_user.disabled=True


In [None]:
# | export


def update_user_info_in_db(
    sso_signup_trial_username: str,
    user_info_from_provider: Dict[str, Any],
) -> None:
    """Update user information retrieved from an external SSO provider in the database.

    Two cases are handled, keyed on the email reported by the provider:

    * A user with that email already exists: the trial user's SSO record is moved to
      the existing user, the existing user's ``sso_signup_trial_username`` is set, and
      the trial user is disabled.
    * No such user exists: the trial user itself is renamed to
      ``<name, lowercased, spaces -> underscores>_<random 0..1000>``, and its first
      name, last name and email are taken from the provider.

    Args:
        sso_signup_trial_username: The username of the SSO signup trial user.
        user_info_from_provider: User information retrieved from an external SSO
            provider. ``email`` is always read; ``name``, ``given_name`` and
            ``family_name`` are read for a new user (presumably Google's userinfo
            fields; confirm for other providers). Values need not be strings
            (e.g. ``email_verified`` is a bool).

    Raises:
        NoResultFound: If no user named ``sso_signup_trial_username`` exists (new-user case).
        HTTPException: Propagated from ``update_user_id_in_sso_table`` (existing-user case).
    """
    is_existing_user = False
    with get_session_with_context() as session:
        try:
            user = session.exec(
                select(User).where(User.email == user_info_from_provider["email"])
            ).one()

            is_existing_user = True
            # NOTE(review): if `user` already has an SSO record for the same provider,
            # it now owns two, and `get_sso_if_enabled_for_user` (which calls `.one()`)
            # would fail for it. Confirm how duplicates should be handled.
            update_user_id_in_sso_table(sso_signup_trial_username, user.id)

        except NoResultFound:
            user = session.exec(
                select(User).where(User.username == sso_signup_trial_username)
            ).one()

            # NOTE(review): the random suffix is not checked against existing usernames
            random_number = secrets.SystemRandom().randint(0, 1000)
            name_slug = user_info_from_provider["name"].replace(" ", "_").lower()

            user.username = f"{name_slug}_{random_number}"
            user.first_name = user_info_from_provider["given_name"]
            user.last_name = user_info_from_provider["family_name"]
            user.email = user_info_from_provider["email"]

        user.sso_signup_trial_username = sso_signup_trial_username

        with commit_or_rollback(session):
            session.add(user)

        # The trial account is superseded by the existing user, so disable it
        if is_existing_user:
            disable_trial_user(sso_signup_trial_username)

In [None]:
# Update first time user details.
# No user has the provider's email yet, so the trial user itself is renamed and
# filled in, and the SSO record stays attached to it.
# NOTE(review): `print(user)` also prints the password hash; consider printing
# selected fields only.
with create_sso_user() as user_and_session:
    sso_enabled_user = user_and_session[0]
    session = user_and_session[1]

    user_info_from_provider = {
        "sub": "10111231231233982347423",
        "name": f"John Doe {sso_enabled_user.username}",
        "given_name": "John",
        "family_name": "Doe",
        "picture": "https://lh3.googleusercontent.com/a/someRandomString=s96-c",
        "email": f"john.doe.{sso_enabled_user.username}@gmail.com",
        "email_verified": True,
        "locale": "en",
    }

    update_user_info_in_db(sso_enabled_user.username, user_info_from_provider)

    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.sso_signup_trial_username == sso_enabled_user.username)).one()
        print(user)

        assert user_info_from_provider["name"].replace(" ", "_").lower() in user.username, user.username
        assert user.first_name == user_info_from_provider["given_name"]
        assert user.last_name == user_info_from_provider["family_name"]
        assert user.email == user_info_from_provider["email"]
        assert user.sso_signup_trial_username == sso_enabled_user.username

        sso = session.exec(select(SSO).where(SSO.user_id == user.id)).one()
        print(f"\n{sso}")

SSO()

subscription_type=<SubscriptionType.test: 'test'> disabled=False phone_number=None mfa_secret=None first_name='John' is_mfa_active=False last_name='Doe' created=datetime.datetime(2023, 3, 10, 10, 25, 18) email='john.doe.goyjcmchle@gmail.com' super_user=False is_phone_number_verified=False id=1069 sso_signup_trial_username='goyjcmchle' username='john_doe_goyjcmchle_100' uuid=UUID('4f242577-f3d1-4642-a016-04054475497f') password='$2b$12$X0brzJZ5lPrrKvUc9K4Qw.Kf2Lanysvk4MTKTt630HVcjdXGib04m'

disabled=False sso_email='random_email_id@mail.com' user_id=1069 sso_provider=<SSOProvider.google: 'google'> id=391


In [None]:
# Update existing user details. Only the sso_signup_trial_username field should be updated
# The first `with` block creates the "existing" user; the second block reuses
# `existing_user` and `existing_test_user_info_from_provider` from it.

with create_sso_user() as user_and_session:
    existing_test_sso_enabled_user = user_and_session[0]
    session = user_and_session[1]

    existing_test_user_info_from_provider = {
        "sub": "10111231231233982347423",
        "name": f"John Doe {existing_test_sso_enabled_user.username}",
        "given_name": "John",
        "family_name": "Doe",
        "picture": "https://lh3.googleusercontent.com/a/someRandomString=s96-c",
        "email": f"john.doe.{existing_test_sso_enabled_user.username}@gmail.com",
        "email_verified": True,
        "locale": "en",
    }

    update_user_info_in_db(existing_test_sso_enabled_user.username, existing_test_user_info_from_provider)

    with get_session_with_context() as session:
        existing_user = session.exec(select(User).where(User.sso_signup_trial_username == existing_test_sso_enabled_user.username)).one()

        assert existing_test_user_info_from_provider["name"].replace(" ", "_").lower() in existing_user.username, existing_user.username
        assert existing_user.first_name == existing_test_user_info_from_provider["given_name"], existing_user.first_name
        assert existing_user.last_name == existing_test_user_info_from_provider["family_name"], existing_user.last_name
        assert existing_user.email == existing_test_user_info_from_provider["email"], existing_user.email
        assert existing_user.sso_signup_trial_username == existing_test_sso_enabled_user.username, existing_user.sso_signup_trial_username

# existing user logging next time with same SSO email
# A second trial user signs in with the same email: the existing account must be
# reused (same name/email), only its sso_signup_trial_username changes, and the
# new trial user must end up disabled.

with create_sso_user() as user_and_session:
    test_user = user_and_session[0]
    session = user_and_session[1]
    print(test_user.username)
    update_user_info_in_db(test_user.username, existing_test_user_info_from_provider)

    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.sso_signup_trial_username == test_user.username)).one()

        assert existing_test_user_info_from_provider["name"].replace(" ", "_").lower() in user.username, user.username
        assert user.first_name == existing_test_user_info_from_provider["given_name"]
        assert user.last_name == existing_test_user_info_from_provider["family_name"]
        assert user.email == existing_test_user_info_from_provider["email"]
        assert user.sso_signup_trial_username == test_user.username

    assert user.username == existing_user.username
    assert user.first_name == existing_user.first_name
    assert user.last_name == existing_user.last_name
    assert user.email == existing_user.email
    assert user.sso_signup_trial_username != existing_user.sso_signup_trial_username

    print(user.sso_signup_trial_username)
    
with get_session_with_context() as session:
    test_user = session.exec(select(User).where(User.username == test_user.username)).one()
    print(f"{test_user.disabled=}")
    assert test_user.disabled

SSO()

SSO()

eitjxjchay
eitjxjchay
test_user.disabled=True


In [None]:
# | export


def validate_sso_response(request: Request, sso_provider: str) -> str:
    """Validate the response from the SSO provider

    This function receives the callback from the SSO provider along with the query
    parameters and validates it. The outcome (the success flag and the error message,
    if any) is stored in the corresponding SSO protocol record in the database.

    Args:
        request: The callback request object
        sso_provider: The name of the SSO provider

    Returns:
        The text message indicating the status of the SSO authentication to display in the browser.

    Raises:
        HTTPException: If the `state` query parameter is missing or malformed, the SSO
            session has expired, or the email returned by the provider does not match
            our records. Exceptions raised by `get_sso_protocol_and_email` and
            `get_user_info_from_provider` are propagated unchanged.
    """
    sso_protocol_error: Optional[str] = None
    is_sso_successful = False

    # Step 1: get username and nonce from response's state ("<nonce>_<username>")
    try:
        state = request.query_params["state"]
        nonce, username = state.split("_", 1)
    except (KeyError, ValueError) as e:
        # A missing state, or one without the "_" separator, cannot match the state
        # we generated, so both are rejected with the CSRF warning
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["SSO_CSRF_WARNING"],
        ) from e

    # Step 2: Validate and get record from sso protocol table
    sso_protocol, sso_email = get_sso_protocol_and_email(username, nonce, sso_provider)

    try:
        # Step 3: Check if SESSION_TIME_LIMIT exceeded
        if (datetime.utcnow() - sso_protocol.created_at) > timedelta(
            minutes=SESSION_TIME_LIMIT
        ):
            sso_protocol_error = ERRORS["SSO_SESSION_EXPIRED"]
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERRORS["SSO_SESSION_EXPIRED"],
            )

        # Step 4: get email from SSO provider and validate against our records
        user_info_from_provider = get_user_info_from_provider(
            url=str(request.url),
            nonce_with_username=state,
            sso_provider=sso_provider,
        )
        email_from_provider: Optional[str]
        if sso_provider == "google":
            email_from_provider = user_info_from_provider["email"]  # type: ignore
        else:
            # Non-Google providers (GitHub in SSO_CONFIG, whose userinfo endpoint is
            # /user/emails) return a list of email records; use the primary one.
            # None (no primary email) simply fails the comparison below.
            email_from_provider = next(
                (
                    email["email"]
                    for email in user_info_from_provider  # type: ignore
                    if email["primary"]
                ),
                None,
            )
        if email_from_provider != sso_email and "captn_trial" not in username:
            sso_protocol_error = ERRORS["SSO_EMAIL_NOT_SAME"]
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERRORS["SSO_EMAIL_NOT_SAME"],
            )

        # Only reached when every check above passed
        is_sso_successful = True
    except HTTPException as e:
        # Also record errors raised inside the helpers, not only the ones set above
        if sso_protocol_error is None:
            sso_protocol_error = str(e.detail)
        raise
    finally:
        # Step 5: Update error and is_sso_successful in database.
        # Any exception escaping the try block leaves is_sso_successful False.
        # NOTE(review): non-HTTP exceptions are stored with error=None; confirm this
        # is how callers polling the protocol record should see an unexpected failure.
        with get_session_with_context() as session:
            sso_protocol.error = sso_protocol_error
            sso_protocol.is_sso_successful = is_sso_successful
            session.add(sso_protocol)
            session.commit()

    if "captn_trial" in username:
        update_user_info_in_db(username, user_info_from_provider)  # type: ignore

    return SSO_SUCCESS_MSG

In [None]:
# Negative Scenario: state variable not present in the query param
# Only unrelated query parameters are sent, so reading "state" fails and the
# callback must be rejected with the CSRF warning.
# NOTE(review): ASGI scopes normally carry `query_string` as bytes; a dict is
# accepted for `request.query_params` here — confirm before touching `request.url`.
scope = dict(
    type="http",
    query_string=dict(authuser="0", prompt="consent"),
)
r = Request(scope=scope)
with pytest.raises(HTTPException) as e:
    validate_sso_response(r, sso_provider="google")

assert "Request check failed" in e.value.detail, e.value.detail
e.value.detail

'Request check failed: State not equal in request and response. For your protection, access to this resource is secured against CSRF. Please re-generate the authentication URL and initiate the SSO login process again.'

In [None]:
# Negative Scenario: passing invalid user
# The state splits into nonce "dkco0tuhs" and username "user_name"; that username
# is not a known user, so the lookup must fail with "Incorrect username".
scope = dict(
    type="http",
    query_string=dict(
        state="dkco0tuhs_user_name",
        code="come_code",
        scope="email https://www.googleapis.com/auth/userinfo.email openid",
        authuser="0",
        prompt="consent",
    ),
)
r = Request(scope=scope)
with pytest.raises(HTTPException) as e:
    validate_sso_response(r, sso_provider="google")

assert "Incorrect username" in e.value.detail
e.value.detail

'Incorrect username. Please try again.'

In [None]:
# Negative Scenario: passing disabled sso provider
# The user is created with sso_provider="github" (presumably the only provider
# enabled for them), so a Google callback must be refused.
with create_sso_user(sso_provider="github") as user_and_session:
    user, session = user_and_session[0], user_and_session[1]
    scope = dict(
        type="http",
        query_string=dict(
            state=f"dkco0tuhs_{user.username}",
            code="come_code",
            scope="email https://www.googleapis.com/auth/userinfo.email openid",
            authuser="0",
            prompt="consent",
        ),
    )
    r = Request(scope=scope)
    with pytest.raises(HTTPException) as e:
        validate_sso_response(r, sso_provider="google")

assert "SSO is not enabled for the provider" in e.value.detail
e.value.detail

SSO()

'SSO is not enabled for the provider.'

In [None]:
# Negative Scenario: CSRF
# A real SSO-enabled user, but the hard-coded nonce "dkco0tuhs" presumably does not
# match any stored SSO protocol record, so the callback must be rejected as CSRF.
with create_sso_user() as user_and_session:
    user, session = user_and_session[0], user_and_session[1]
    scope = dict(
        type="http",
        query_string=dict(
            state=f"dkco0tuhs_{user.username}",
            code="come_code",
            scope="email https://www.googleapis.com/auth/userinfo.email openid",
            authuser="0",
            prompt="consent",
        ),
    )
    r = Request(scope=scope)
    with pytest.raises(HTTPException) as e:
        validate_sso_response(r, sso_provider="google")

assert "CSRF" in e.value.detail
e.value.detail

SSO()

'Request check failed: State not equal in request and response. For your protection, access to this resource is secured against CSRF. Please re-generate the authentication URL and initiate the SSO login process again.'