In [None]:
#| default_exp auth

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] numexpr.utils: Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
#| export


import json
import uuid
from datetime import datetime, timedelta
from os import environ
import secrets
from typing import *
from urllib.parse import urlparse, parse_qs


# from fastcore.foundation import patch
from fastapi import APIRouter, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi import Depends, HTTPException, status, Query
from pydantic import BaseModel
from jose import JWTError, jwt
from sqlalchemy.exc import NoResultFound, MultipleResultsFound
from sqlmodel import Session, select

from airt.logger import get_logger
from airt.patching import patch

import airt_service.sanitizer
from airt_service.db.models import get_session, get_session_with_context
from airt_service.db.models import APIKeyCreate, APIKey, APIKeyRead, User, UserRead, SSO
from airt_service.errors import HTTPError, ERRORS
from airt_service.helpers import commit_or_rollback, verify_password
from airt_service.totp import validate_totp, require_otp_if_mfa_enabled
from airt_service.sso import (
    SSOAuthURL,
    get_valid_sso_providers,
    initiate_sso_flow,
    get_sso_if_enabled_for_user,
    validate_sso_response,
    get_sso_protocol_and_email,
    SESSION_TIME_LIMIT,
)
from airt_service.sms_utils import validate_otp

In [None]:
import time
from copy import deepcopy
import random
import string
from contextlib import contextmanager

import pytest
from fastapi import Request
from starlette.datastructures import Headers
import pyotp

from airt_service.db.models import (
    create_user_for_testing,
    get_session_with_context,
    SSOProtocol,
)
from airt_service.totp import generate_mfa_provisioning_url, generate_mfa_secret
from airt_service.users import (
    generate_mfa_url,
    activate_mfa,
    ActivateMFARequest,
    disable_mfa,
    enable_sso,
    EnableSSORequest,
    disable_sso,
)

[INFO] airt.executor.subcommand: Module loaded.


In [None]:
#| exporti

# Module-level logger named after this module (__name__)
logger = get_logger(__name__)

In [None]:
# Throwaway user shared by the tests below; create_user_for_testing returns its username
test_username = create_user_for_testing()
display(test_username)

'svwctktgmt'

In [None]:
# Well-formed but (assumed) unused UUID, used to look up an API key that must not exist
INVALID_UUID_FOR_TESTING = "00000000-0000-0000-0000-000000000000"

In [None]:
#| export


# JWT signing algorithm passed to jose's jwt.encode / jwt.decode (HMAC-SHA256)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 180 * 24 * 60  # Expire after 180 days

In [None]:
#| export


# Extracts the bearer token from requests; tokenUrl points clients at the "/token" login route
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

In [None]:
#| export


def get_user(username: str) -> Optional[User]:
    """Look up a user by username.

    Args:
        username: Username as a string

    Returns:
        The matching User object, or None when no user has that username.
        The object is loaded in a short-lived session that is closed on return.
    """
    query = select(User).where(User.username == username)
    with get_session_with_context() as db_session:
        try:
            return db_session.exec(query).one()
        except NoResultFound:
            return None

In [None]:
# Known username -> User object
actual = get_user(username=test_username)
display(actual)
assert actual.username == test_username

# Unknown username -> None
actual = get_user(username="username_does_not_exists")
assert actual is None

User(id=11, uuid=UUID('d49d8a24-f5a5-4615-ba80-175904db5d5c'), username='svwctktgmt', first_name='unittest', last_name='user', email='svwctktgmt@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2022, 10, 21, 5, 14, 17), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
#| export


def get_password_and_otp_from_json(password: str) -> Tuple[str, str]:
    """Split an MFA login payload into the password and the OTP.

    MFA users send the JSON string ``{"password": ..., "user_otp": ...}`` in the
    password field of the login form.

    Args:
        password: password from form_data (expected to be the JSON payload above)

    Returns:
        The password and otp as Tuple

    Raises:
        HTTPException: 400 (OTP required) if ``password`` is not a JSON object
            containing both "password" and "user_otp".
    """
    try:
        password_dict = json.loads(password)
        return password_dict["password"], password_dict["user_otp"]
    except (json.decoder.JSONDecodeError, KeyError, TypeError) as e:
        # JSONDecodeError: a plain password was sent.
        # TypeError: valid JSON that is not an object (e.g. an all-digit password
        # parses as a number and cannot be indexed by key).
        # KeyError: a JSON object missing one of the two fields.
        # In every case the OTP was not supplied, so report OTP_REQUIRED, not a 500.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERRORS["OTP_REQUIRED"]
        ) from e

In [None]:
# A JSON payload is split into (password, otp)
r = json.dumps({"password": "random password", "user_otp": "123123"})
display(get_password_and_otp_from_json(r))
assert get_password_and_otp_from_json(r) == ("random password", "123123")

# A plain (non-JSON) password is rejected with HTTPException
r = "random password"
with pytest.raises(HTTPException) as e:
    get_password_and_otp_from_json(r)

('random password', '123123')

In [None]:
#| export


def authenticate_user(username: str, password: str) -> Optional[User]:
    """Validate if the password matches the user's previously stored password.

    For users with MFA activated, ``password`` must be a JSON string of the form
    ``{"password": ..., "user_otp": ...}``. The OTP is first checked as an
    authenticator-app TOTP and, if that fails, as an OTP requested via SMS.

    Args:
        username: Username of the user
        password: Password to validate (the JSON payload above for MFA users)

    Returns:
        The user object if the credentials match, else None

    Raises:
        HTTPException: If an MFA user's password is not the expected JSON payload,
            or if the OTP is neither a valid TOTP nor a valid SMS OTP.
    """
    user = get_user(username)
    if user is None:
        return None

    if user.is_mfa_active:
        password, otp_or_totp = get_password_and_otp_from_json(password)

    if not verify_password(password, user.password):
        return None

    if user.is_mfa_active:
        try:
            validate_totp(user.mfa_secret, otp_or_totp)  # type: ignore
        except HTTPException:
            # Not a valid authenticator-app TOTP: fall back to an SMS OTP.
            # get_session() is a generator; close it explicitly so any cleanup after
            # its yield runs now instead of whenever the generator is garbage collected.
            session_gen = get_session()
            try:
                validate_otp(
                    user=user,
                    otp=otp_or_totp,
                    message_template_name="get_token",
                    session=next(session_gen),
                )
            except HTTPException as e:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=ERRORS["INVALID_OTP"],
                ) from e
            finally:
                session_gen.close()

    return user

In [None]:
# Unknown user and wrong password both yield a falsy result
assert not authenticate_user(
    username="username_does_not_exists", password="password_is_wrong"
)
assert not authenticate_user(username=test_username, password="password_is_wrong")

# Correct password for a non-MFA user returns the User object
actual = authenticate_user(
    username=test_username, password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"]
)
display(actual)
assert actual.username == test_username

User(id=11, uuid=UUID('d49d8a24-f5a5-4615-ba80-175904db5d5c'), username='svwctktgmt', first_name='unittest', last_name='user', email='svwctktgmt@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2022, 10, 21, 5, 14, 17), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
#| export


def create_access_token(data: dict, expire: Optional[datetime] = None) -> str:
    """Encode the given claims as a signed JWT.

    Args:
        data: Claims to put into the token; the caller's dict is not modified
        expire: Optional expiry; when truthy it is stored as the "exp" claim

    Returns:
        The encoded JWT, signed with the AIRT_TOKEN_SECRET_KEY environment
        variable using ALGORITHM
    """
    claims = data.copy()  # work on a copy so the caller's dict stays untouched
    if expire:
        claims["exp"] = expire

    return jwt.encode(
        claims,
        # nosemgrep: python.jwt.security.jwt-hardcode.jwt-python-hardcoded-secret
        environ["AIRT_TOKEN_SECRET_KEY"],
        algorithm=ALGORITHM,
    )

In [None]:
expire = datetime.utcnow() + timedelta(minutes=15)
actual = create_access_token(data={"sub": test_username}, expire=expire)
payload = jwt.decode(actual, environ["AIRT_TOKEN_SECRET_KEY"], algorithms=[ALGORITHM])
display(payload)
assert payload["sub"] == test_username
# `expire` is a naive UTC datetime, so convert "exp" back as UTC as well;
# datetime.fromtimestamp would yield local time and only match on UTC machines.
actual = datetime.utcfromtimestamp(payload["exp"])
assert actual == expire.replace(microsecond=0)

{'sub': 'svwctktgmt', 'exp': 1666330157}

In [None]:
def generate_random_name(size=15, chars=string.ascii_uppercase + string.digits):
    """Return a random test name of `size` characters drawn from `chars`."""
    picked = [random.choice(chars) for _ in range(size)]
    return "".join(picked)


assert len(generate_random_name()) == 15
assert isinstance(generate_random_name(), str)

In [None]:
#| export

# Router for the authentication endpoints; every route documents a 500 HTTPError response
auth_router = APIRouter(
    responses={
        500: {
            "model": HTTPError,
            "description": ERRORS["INTERNAL_SERVER_ERROR"],
        }
    }
)

In [None]:
#| export


class Token(BaseModel):
    """A base class for creating and managing Access token

    Args:
        access_token: Access token
        token_type: Type of the token (bearer token is the only one currently supported)
    """

    # NOTE: the class docstring above is published in the API schema, so it is left as-is.
    # Encoded JWT, as produced by create_access_token
    access_token: str
    # Set to "bearer" by generate_token
    token_type: str

In [None]:
#| exporti


def generate_token(username: str) -> Token:
    """Issue a bearer token for the given user.

    The token's "sub" claim is the username and it expires
    ACCESS_TOKEN_EXPIRE_MINUTES from now (UTC).

    Args:
        username: Username as a string

    Returns:
        The generated access token wrapped in a Token model
    """
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    expires_at = datetime.utcnow() + lifetime
    jwt_token = create_access_token(data={"sub": username}, expire=expires_at)  # type: ignore

    # SAST recognizes the "bearer" string as a hardcoded password, but it is not. Hence nosec B106.
    return Token(access_token=jwt_token, token_type="bearer")  # nosec B106

In [None]:
# generate_token works for any username (no DB lookup is involved)
username = "random_username"
actual = generate_token(username)

# Mask the token so its value does not end up in the notebook output
display(f'access_token: {"*"*len(actual.access_token)}')
assert actual.access_token
assert actual.token_type == "bearer"

'access_token: *****************************************************************************************************************************************'

In [None]:
#| export


@auth_router.post(
    "/token",
    response_model=Token,
    responses={
        400: {
            "model": HTTPError,
            "description": ERRORS["INCORRECT_USERNAME_OR_PASSWORD"],
        }
    },
)
def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
) -> Token:
    """Authenticate with credentials"""
    # (The docstring above is published by FastAPI as the endpoint description.)
    # MFA users send a JSON {"password", "user_otp"} string in the password field;
    # authenticate_user unpacks and validates it.
    authenticated = authenticate_user(form_data.username, form_data.password)
    if authenticated is not None:
        return generate_token(authenticated.username)

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=ERRORS["INCORRECT_USERNAME_OR_PASSWORD"],
        headers={"WWW-Authenticate": "Bearer"},
    )

In [None]:
# context manager to create a user for testing


@contextmanager
def create_test_user():
    """Yield a freshly created test user with its session; disable the user on exit.

    Yields:
        Tuple of (User, Session)
    """
    with get_session_with_context() as session:
        username = create_user_for_testing()
        # Look the user up before entering `try`, so the cleanup in `finally` can
        # never run with `user` unbound (which would mask the real error).
        user = session.exec(select(User).where(User.username == username)).one()
        try:
            yield user, session
        finally:
            # deactivate the user
            with commit_or_rollback(session):
                user.disabled = True
                session.add(user)
                session.commit()
            assert user.disabled

In [None]:
# Password login for a plain (non-MFA) user
with create_test_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    # Negative Scenario: non-MFA user passing wrong password
    with pytest.raises(HTTPException):
        login_for_access_token(
            form_data=OAuth2PasswordRequestForm(
                username=user.username, password="wrong password", scope="scope"
            )
        )

    # Positive Scenario: non-MFA user passing valid password
    actual = login_for_access_token(
        form_data=OAuth2PasswordRequestForm(
            username=user.username,
            password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
            scope="scope",
        )
    )

# Mask the token so its value does not end up in the notebook output
display(f'access_token: {"*"*len(actual.access_token)}')
assert actual.access_token
assert actual.token_type == "bearer"

'access_token: ***********************************************************************************************************************************'

In [None]:
# Negative Scenario: non-MFA user passing otp
# For non-MFA users the JSON string is compared as the literal password, so login fails
with create_test_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    random_otp = "123456"
    with pytest.raises(HTTPException) as e:
        login_for_access_token(
            form_data=OAuth2PasswordRequestForm(
                username=user.username,
                password=json.dumps(
                    {
                        "password": environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
                        "user_otp": random_otp,
                    }
                ),
                scope="scope",
            )
        )
    display(e.value.detail)

'Incorrect username or password. Please try again.'

In [None]:
# Context manager to create MFA enabled user


@contextmanager
def create_mfa_enabled_user():
    """Yield a new test user with MFA activated, together with its session.

    On exit, MFA is disabled again (using a current TOTP derived from the
    user's secret) and the user is marked as disabled.

    Yields:
        Tuple of (User, Session)
    """
    with get_session_with_context() as session:
        mfa_enabled_user = create_user_for_testing()
        user = session.exec(select(User).where(User.username == mfa_enabled_user)).one()
        try:
            # generate MFA (stores a new mfa_secret on the user)
            actual = generate_mfa_url(user=user, session=session)
            assert user.mfa_secret is not None
            # activate MFA by proving possession of the secret with a current TOTP
            activate_mfa_request = ActivateMFARequest(
                user_otp=pyotp.TOTP(user.mfa_secret).now()
            )
            actual = activate_mfa(
                activate_mfa_request=activate_mfa_request, user=user, session=session
            )
            yield user, session
        finally:
            # deactivate MFA
            # NOTE(review): if MFA setup above failed, user.mfa_secret may still be None
            # and generating a TOTP from it may raise, masking the original error — confirm.
            user = disable_mfa(
                user_uuid_or_name=user.username,
                otp=pyotp.TOTP(user.mfa_secret).now(),
                user=user,
                session=session,
            )
            with commit_or_rollback(session):
                user.disabled = True
                session.add(user)
            assert user.mfa_secret is None
            assert user.disabled

In [None]:
# Positive scenario: MFA user passing valid otp
# MFA users send password and OTP together as a JSON string in the password field
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    actual = login_for_access_token(
        form_data=OAuth2PasswordRequestForm(
            username=user.username,
            password=json.dumps(
                {
                    "password": environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
                    "user_otp": pyotp.TOTP(user.mfa_secret).now(),
                }
            ),
            scope="scope",
        )
    )

# Mask the token so its value does not end up in the notebook output
display(f'access_token: {"*"*len(actual.access_token)}')
assert actual.access_token
assert actual.token_type == "bearer"

'access_token: ***********************************************************************************************************************************'

In [None]:
# Negative scenario: MFA user not passing otp
# A plain password is not the expected JSON payload, so OTP_REQUIRED is raised
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    with pytest.raises(HTTPException) as e:
        login_for_access_token(
            form_data=OAuth2PasswordRequestForm(
                username=user.username,
                password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
                scope="scope",
            )
        )

assert "OTP is required" in e.value.detail
display(e.value.detail)

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

In [None]:
# Negative scenario: MFA user passing random otp
# The OTP fails both the TOTP check and the SMS OTP fallback -> INVALID_OTP
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    random_otp = "123456"
    with pytest.raises(HTTPException) as e:
        login_for_access_token(
            form_data=OAuth2PasswordRequestForm(
                username=user.username,
                password=json.dumps(
                    {
                        "password": environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
                        "user_otp": random_otp,
                    }
                ),
                scope="scope",
            )
        )

display(e.value.detail)

'Invalid OTP. Please try again.'

In [None]:
#| export


class SSOInitiateRequest(BaseModel):
    """A base class for initiating SSO for the provider

    Args:
        username: Username as a string
        password: password as a string
        sso_provider: The name of the sso provider
    """

    # NOTE: the class docstring above is published in the API schema, so it is left as-is.
    username: str
    # Plain password, or for MFA users a JSON string {"password": ..., "user_otp": ...}
    password: str
    # Checked against get_valid_sso_providers() by the /sso/initiate route
    sso_provider: str

In [None]:
#| export


@auth_router.post(
    "/sso/initiate",
    response_model=SSOAuthURL,
    responses={
        400: {
            "model": HTTPError,
            "description": ERRORS["INCORRECT_USERNAME_OR_PASSWORD"],
        }
    },
)
def login_for_sso_access_token(
    sso_initiate_request: SSOInitiateRequest,
) -> SSOAuthURL:
    """Initiate the SSO authentication"""
    # (The docstring above is published by FastAPI as the endpoint description.)
    # Steps: check credentials (incl. OTP for MFA users), validate the provider name,
    # require SSO to be enabled for this user/provider, then start the flow with a
    # fresh random nonce.
    user = authenticate_user(sso_initiate_request.username, sso_initiate_request.password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INCORRECT_USERNAME_OR_PASSWORD"],
        )

    requested_provider = sso_initiate_request.sso_provider
    allowed_providers = get_valid_sso_providers()
    if requested_provider not in allowed_providers:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f'{ERRORS["INVALID_SSO_PROVIDER"]}: {allowed_providers}',
        )

    enabled_sso = get_sso_if_enabled_for_user(user.username, requested_provider)
    if enabled_sso is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["SSO_NOT_ENABLED_FOR_SERVICE"],
        )

    fresh_nonce = secrets.token_hex()
    return initiate_sso_flow(
        username=user.username,
        sso_provider=requested_provider,
        nonce=fresh_nonce,
        sso=enabled_sso,
    )

In [None]:
# Positive Scenario: MFA and SSO enabled user getting authorization_url
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    # For this MFA user, the enable-SSO request carries a current TOTP
    enable_sso_request = EnableSSORequest(
        sso_provider="google",
        sso_email="random_email_id@mail.com",
        otp=pyotp.TOTP(user.mfa_secret).now(),
    )
    enable_sso(enable_sso_request=enable_sso_request, user=user, session=session)

    # The password field holds the JSON password + OTP payload expected for MFA users
    sso_initiate_request = SSOInitiateRequest(
        username=user.username,
        password=json.dumps(
            {
                "password": environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
                "user_otp": pyotp.TOTP(user.mfa_secret).now(),
            }
        ),
        sso_provider="google",
    )

    actual = login_for_sso_access_token(sso_initiate_request)
    display(actual)

SSOAuthURL(authorization_url='https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=842138153914-6kvm51cpin7iocg3nrsnl44s3d24u047.apps.googleusercontent.com&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+openid&state=62d3659f16e5f521db045bcd9d2a5c2a9332f9233e695df7a5b0104571cc4c3c_kqgrfchiuw&prompt=select_account')

In [None]:
# Negative Scenario: MFA and SSO enabled user passing invalid OTP
# SSO is enabled with a valid TOTP; the later login attempt uses a made-up OTP
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    enable_sso_request = EnableSSORequest(
        sso_provider="google",
        sso_email="random_email_id@mail.com",
        otp=pyotp.TOTP(user.mfa_secret).now(),
    )
    enable_sso(enable_sso_request=enable_sso_request, user=user, session=session)
    invalid_otp = "123456"
    sso_initiate_request = SSOInitiateRequest(
        username=user.username,
        password=json.dumps(
            {
                "password": environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
                "user_otp": invalid_otp,
            }
        ),
        sso_provider="google",
    )
    with pytest.raises(HTTPException) as e:
        login_for_sso_access_token(sso_initiate_request)
display(e.value.detail)

'Invalid OTP. Please try again.'

In [None]:
# Negative Scenario: Non-MFA user sending OTP
# For non-MFA users the JSON string is treated as the literal password, so it is rejected
with create_test_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    enable_sso_request = EnableSSORequest(
        sso_provider="google", sso_email="random_email_id@mail.com"
    )
    enable_sso(enable_sso_request=enable_sso_request, user=user, session=session)
    invalid_otp = "123456"
    sso_initiate_request = SSOInitiateRequest(
        username=user.username,
        password=json.dumps(
            {
                "password": environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
                "user_otp": invalid_otp,
            }
        ),
        sso_provider="google",
    )
    with pytest.raises(HTTPException) as e:
        login_for_sso_access_token(sso_initiate_request)
display(e.value.detail)

'Incorrect username or password. Please try again.'

In [None]:
# context manager to create a SSO enabled user for GOOGLE


@contextmanager
def create_sso_user(
    sso_provider: str = "google", sso_email: str = "random_email_id@mail.com"
):
    """Yield a new (non-MFA) test user with SSO enabled, together with its session.

    Args:
        sso_provider: Provider to enable SSO for
        sso_email: Email address registered with the provider

    Yields:
        Tuple of (User, Session); the user is marked as disabled on exit.
    """
    with get_session_with_context() as session:
        new_username = create_user_for_testing()
        user = session.exec(select(User).where(User.username == new_username)).one()
        try:
            sso_request = EnableSSORequest(sso_provider=sso_provider, sso_email=sso_email)
            display(enable_sso(enable_sso_request=sso_request, user=user, session=session))
            yield user, session
        finally:
            # deactivate User
            with commit_or_rollback(session):
                user.disabled = True
                session.add(user)
                session.commit()
            assert user.disabled

In [None]:
# Negative scenario: Passing invalid SSO provider
# The error message lists the valid provider names
with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    sso_initiate_request = SSOInitiateRequest(
        username=user.username,
        password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
        sso_provider="invalid_sso_provider",
    )

    with pytest.raises(HTTPException) as e:
        login_for_sso_access_token(sso_initiate_request)
display(e.value.detail)

SSO()

"Invalid SSO provider. Valid SSO providers are: ['google', 'github']"

In [None]:
# Negative scenario: Passing an SSO provider which is not yet enabled
# (create_sso_user enables only "google")
with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    sso_initiate_request = SSOInitiateRequest(
        username=user.username,
        password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
        sso_provider="github",
    )

    with pytest.raises(HTTPException) as e:
        login_for_sso_access_token(sso_initiate_request)
display(e.value.detail)

SSO()

'SSO is not enabled for the provider.'

In [None]:
# Negative scenario: Passing a disabled SSO provider
# SSO for "google" is enabled by create_sso_user and then disabled again
with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    disable_sso(
        user_uuid_or_name=str(user.uuid),
        sso_provider="google",
        user=user,
        session=session,
    )

    sso_initiate_request = SSOInitiateRequest(
        username=user.username,
        password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
        sso_provider="google",
    )

    with pytest.raises(HTTPException) as e:
        login_for_sso_access_token(sso_initiate_request)
display(e.value.detail)

SSO()

'SSO is not enabled for the provider.'

In [None]:
# Positive scenario
# The username is embedded in the authorization URL (inside its "state" parameter)
with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    sso_initiate_request = SSOInitiateRequest(
        username=user.username,
        password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
        sso_provider="google",
    )

    actual = login_for_sso_access_token(sso_initiate_request)
    display(actual.authorization_url)
    assert user.username in actual.authorization_url

SSO()

'https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=842138153914-6kvm51cpin7iocg3nrsnl44s3d24u047.apps.googleusercontent.com&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+openid&state=d20f215764c0e2c3d6b0a9ac811017cb916b1dd65b582fbad842cacb2265aec4_seucxcaorp&prompt=select_account'

In [None]:
#| export


@auth_router.get("/sso/callback")
def sso_google_callback(request: Request) -> str:
    """SSO callback route"""
    # (The docstring above is published by FastAPI as the endpoint description.)
    # The provider is inferred from the callback URL: anything mentioning "googleapis"
    # is treated as Google, everything else as GitHub.
    # NOTE(review): presumably Google's redirect echoes googleapis.com scopes in the query — confirm.
    provider = "github"
    if "googleapis" in str(request.url):
        provider = "google"
    return validate_sso_response(sso_provider=provider, request=request)

In [None]:
#| export


@auth_router.get(
    "/sso/token",
    responses={
        400: {
            "model": HTTPError,
            "description": ERRORS["INCORRECT_USER_USERNAME"],
        }
    },
)
def finish_sso_flow(authorization_url: str) -> Token:
    """Finish SSO flow"""
    # (The docstring above is published by FastAPI as the endpoint description.)
    # authorization_url: the URL returned by /sso/initiate. Its "state" query parameter
    #   has the form "<nonce>_<username>".
    # Returns a bearer Token once the SSO flow for that nonce/user has succeeded.
    # Raises HTTPException: 400 if the URL carries no well-formed state; 401 if the SSO
    #   session is older than SESSION_TIME_LIMIT minutes or has not finished yet.
    sso_provider = "google" if "accounts.google.com" in authorization_url else "github"

    state_values = parse_qs(urlparse(authorization_url).query).get("state")
    # The URL is caller-supplied: reject a missing or malformed state with the route's
    # documented 400 instead of letting KeyError/ValueError surface as a 500.
    if not state_values or "_" not in state_values[0]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INCORRECT_USER_USERNAME"],
        )
    # Split on the first "_" only; the remainder (the username) is kept intact.
    nonce, username = state_values[0].split("_", 1)

    sso_protocol, _ = get_sso_protocol_and_email(username, nonce, sso_provider)

    # https://stackoverflow.com/questions/3297048/403-forbidden-vs-401-unauthorized-http-responses
    if (datetime.utcnow() - sso_protocol.created_at) > timedelta(
        minutes=SESSION_TIME_LIMIT
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERRORS["SSO_SESSION_EXPIRED"],
        )
    if not sso_protocol.is_sso_successful:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERRORS["SSO_NOT_YET_FINISHED"],
        )
    return generate_token(username)

In [None]:
# Negative Scenario: passing wrong username
# state=randomNonce_random_user -> nonce "randomNonce", username "random_user" (no such user)
test_auth_url = "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=random_id&redirect_uri=http%3A%2F%2Frandom&state=randomNonce_random_user&prompt=select_account"
with pytest.raises(HTTPException) as e:
    finish_sso_flow(test_auth_url)
assert "Incorrect username" in e.value.detail
e.value.detail

'Incorrect username. Please try again.'

In [None]:
# Negative Scenario: sso enabled for google but passing github

with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    test_auth_url = f"https://github.com/o/oauth2/v2/auth?response_type=code&client_id=random_id&redirect_uri=http%3A%2F%2Frandom&state=randomNonce_{user.username}&prompt=select_account"
    with pytest.raises(HTTPException) as e:
        finish_sso_flow(test_auth_url)
    assert "SSO is not enabled for the provider" in e.value.detail
e.value.detail

SSO()

'SSO is not enabled for the provider.'

In [None]:
# Negative Scenario: passing invalid nonce

# "randomNonce" was never issued for this user, so the state check fails
with create_sso_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    test_auth_url = f"https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=random_id&redirect_uri=http%3A%2F%2Frandom&state=randomNonce_{user.username}&prompt=select_account"
    with pytest.raises(HTTPException) as e:
        finish_sso_flow(test_auth_url)
    assert "Request check failed:" in e.value.detail
e.value.detail

SSO()

'Request check failed: State not equal in request and response. For your protection, access to this resource is secured against CSRF. Please re-generate the authentication URL and initiate the SSO login process again.'

In [None]:
# Negative Scenario: SSO process initiated but not yet completed

# The URL comes from a real initiate call, but the provider callback never ran
sso_provider = "google"
sso_email = "random.mail@gmail.com"

with create_sso_user(
    sso_provider=sso_provider, sso_email=sso_email
) as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    sso_initiate_request = SSOInitiateRequest(
        username=user.username,
        password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
        sso_provider=sso_provider,
    )
    actual = login_for_sso_access_token(sso_initiate_request)
    display(actual.authorization_url)
    assert user.username in actual.authorization_url

    with pytest.raises(HTTPException) as e:
        finish_sso_flow(actual.authorization_url)
    assert "SSO authentication is not complete" in e.value.detail
e.value.detail

SSO()

'https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=842138153914-6kvm51cpin7iocg3nrsnl44s3d24u047.apps.googleusercontent.com&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+openid&state=8493faab48af050dfdacd9ed1a80e76ffcee1eb946d5788c7e42e1967e1ce257_lnyegsuozu&prompt=select_account'

'SSO authentication is not complete. Please click on the authentication link you have received while requesting a new token and complete the login process first.'

In [None]:
#| exporti

get_apikey_responses = {
    400: {"model": HTTPError, "description": ERRORS["APIKEY_REVOKED"]},
    401: {"model": HTTPError, "description": ERRORS["INCORRECT_APIKEY"]},
}


@patch(cls_method=True)
def get(cls: APIKey, key_uuid_or_name: str, user: User, session: Session) -> APIKey:
    """Look up one of the user's API keys by UUID or by name.

    A value that parses as a UUID is matched against ``APIKey.uuid``; anything
    else is matched against ``APIKey.name``. Only keys owned by ``user`` match.
    When several keys share a name, revoked keys are filtered out first.

    Args:
        key_uuid_or_name: UUID string or name of the APIKey object
        user: Owner of the key
        session: Sqlmodel session used for the query

    Returns:
        The matching, non-revoked APIKey object

    Raises:
        HTTPException: 401 if no key matches for this user; 400 if the
            matching key (or every key sharing the name) is revoked
    """
    try:
        parsed_uuid = uuid.UUID(key_uuid_or_name)
    except ValueError:
        key_filter = APIKey.name == key_uuid_or_name
    else:
        key_filter = APIKey.uuid == parsed_uuid
    statement = select(APIKey).where(key_filter, APIKey.user == user)

    try:
        apikey = session.exec(statement).one()
    except NoResultFound:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERRORS["INCORRECT_APIKEY"],
        )
    except MultipleResultsFound:
        # A name reused after revocation matches several keys: keep only active ones
        try:
            apikey = session.exec(statement.where(APIKey.disabled == False)).one()  # noqa: E712
        except NoResultFound:
            # every key carrying this name has been revoked
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail=ERRORS["APIKEY_REVOKED"]
            )

    if not apikey.disabled:
        return apikey
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST, detail=ERRORS["APIKEY_REVOKED"]
    )

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Positive case: Getting details of a valid key by passing key_id
    apikey = APIKey(user=user)
    session.add(apikey)
    session.commit()
    session.refresh(apikey)
    actual = APIKey.get(key_uuid_or_name=str(apikey.uuid), user=user, session=session)
    display(actual)
    assert actual == apikey

    # Positive case: Getting details of a valid key by passing key_name
    key_name = generate_random_name()
    apikey = APIKey(name=key_name, user=user)
    session.add(apikey)
    session.commit()
    session.refresh(apikey)

    actual = APIKey.get(key_uuid_or_name=apikey.name, user=user, session=session)
    display(actual)
    assert actual == apikey

    # Negative case: Getting details of an invalid key_id -> 401
    with pytest.raises(HTTPException) as e:
        APIKey.get(
            key_uuid_or_name=INVALID_UUID_FOR_TESTING, user=user, session=session
        )
    display(e)
    assert e.value.status_code == status.HTTP_401_UNAUTHORIZED

    # Negative case: Getting details of an invalid key_name -> 401
    with pytest.raises(HTTPException) as e:
        APIKey.get(key_uuid_or_name="random-name", user=user, session=session)
    display(e)
    assert e.value.status_code == status.HTTP_401_UNAUTHORIZED

    # Negative case: Getting details of the revoked key -> 400
    apikey_disabled = APIKey(user=user, disabled=True)
    session.add(apikey_disabled)
    session.commit()
    session.refresh(apikey_disabled)
    with pytest.raises(HTTPException) as e:
        APIKey.get(
            key_uuid_or_name=str(apikey_disabled.uuid), user=user, session=session
        )
    display(e)
    assert e.value.status_code == status.HTTP_400_BAD_REQUEST

APIKey(name=None, expiry=None, uuid=UUID('4827718d-592a-4aad-91cf-bea09ea0d5d0'), created=datetime.datetime(2022, 10, 21, 5, 14, 26), id=1, disabled=False, user_id=11)

APIKey(name='AT6WKCBIWVXUGRG', expiry=None, uuid=UUID('3727b0b3-8dc7-44f7-a15e-60fad3d0a8a7'), created=datetime.datetime(2022, 10, 21, 5, 14, 26), id=2, disabled=False, user_id=11)

<ExceptionInfo HTTPException(status_code=401, detail='No such apikey or not enough authorization to access the apikey.') tblen=2>

<ExceptionInfo HTTPException(status_code=401, detail='No such apikey or not enough authorization to access the apikey.') tblen=2>

<ExceptionInfo HTTPException(status_code=400, detail='The Apikey has already been revoked.') tblen=2>

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Positive case: Getting details of the new key created with a revoked key name
    # (the name matches two keys; APIKey.get skips the revoked one)
    key_name_to_revoke = generate_random_name()
    apikey_disabled = APIKey(name=key_name_to_revoke, user=user, disabled=True)
    session.add(apikey_disabled)
    session.commit()
    session.refresh(apikey_disabled)

    # Creating new key with revoked key name
    new_apikey = APIKey(name=key_name_to_revoke, user=user)
    session.add(new_apikey)
    session.commit()
    session.refresh(new_apikey)

    actual = APIKey.get(key_uuid_or_name=new_apikey.name, user=user, session=session)
    assert actual.name == key_name_to_revoke
    display(actual.name)

'JYSTNDVSPKFP6KS'

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Negative case: every key sharing the name is revoked, so the lookup by
    # name reports the key as revoked (400)
    key_name_to_revoke = generate_random_name()
    apikey_disabled = APIKey(name=key_name_to_revoke, user=user, disabled=True)
    session.add(apikey_disabled)
    session.commit()
    session.refresh(apikey_disabled)
    display(apikey_disabled)

    # Creating a second, also revoked, key with the same name
    new_apikey = APIKey(name=key_name_to_revoke, user=user, disabled=True)
    session.add(new_apikey)
    session.commit()
    session.refresh(new_apikey)
    display(new_apikey)

    assert apikey_disabled.name == new_apikey.name

    with pytest.raises(HTTPException) as e:
        APIKey.get(key_uuid_or_name=new_apikey.name, user=user, session=session)

    assert "Apikey has already been revoked" in (e.value.detail)
    assert e.value.status_code == status.HTTP_400_BAD_REQUEST
    # A bare expression inside a `with` block is never rendered, so display explicitly
    display(e.value.detail)

APIKey(name='4IOPCA4GB4L9K1E', expiry=None, uuid=UUID('25b83beb-11b8-413c-864f-e6dd55ea4356'), created=datetime.datetime(2022, 10, 21, 5, 14, 26), id=6, disabled=True, user_id=11)

APIKey(name='4IOPCA4GB4L9K1E', expiry=None, uuid=UUID('cce0be91-998b-4998-94e5-8734655772e3'), created=datetime.datetime(2022, 10, 21, 5, 14, 26), id=7, disabled=True, user_id=11)

In [None]:
#| export


def get_current_active_user(token: str = Depends(oauth2_scheme)) -> User:
    """Get active user details

    Decodes the bearer token and loads the user named in its `sub` claim. Tokens
    issued for an apikey carry a `key_uuid` claim; for those, the apikey is also
    looked up via `APIKey.get` so that a revoked key no longer authenticates.

    Args:
        token: OAuth token

    Returns:
        The active user details

    Raises:
        HTTPException: 401 if the token cannot be decoded (malformed, bad signature,
            expired), lacks a `sub` claim or names an unknown user; 400 if the user
            is disabled; and whatever `APIKey.get` raises for an unusable apikey
            (e.g. a revoked one)
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERRORS["INVALID_CREDENTIALS"],
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(
            token,
            # nosemgrep: python.jwt.security.jwt-hardcode.jwt-python-hardcoded-secret
            environ["AIRT_TOKEN_SECRET_KEY"],
            algorithms=[ALGORITHM],
        )
    except JWTError as e:
        # Keep the decoding error as the cause so it shows up in tracebacks/logs
        raise credentials_exception from e

    username: Optional[str] = payload.get("sub")
    if username is None:
        raise credentials_exception

    user = get_user(username=username)
    if user is None:
        raise credentials_exception
    if user.disabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERRORS["INACTIVE_USER"]
        )

    with get_session_with_context() as session:
        user = session.merge(user)
        if "key_uuid" in payload:
            # Called only for its validation side effect: APIKey.get raises an
            # HTTPException for a revoked key, which rejects the token.
            APIKey.get(key_uuid_or_name=payload["key_uuid"], user=user, session=session)  # type: ignore
    return user  # type: ignore

In [None]:
# Negative case: a token that is not a valid JWT must be rejected
with pytest.raises(HTTPException):
    get_current_active_user("some token")


# Positive case: a freshly issued login token resolves back to the same user
login_form = OAuth2PasswordRequestForm(
    username=test_username,
    password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"],
    scope="scope",
)
token_data = login_for_access_token(form_data=login_form)


actual = get_current_active_user(token_data.access_token)
display(actual)
assert actual.username == test_username

User(id=11, uuid=UUID('d49d8a24-f5a5-4615-ba80-175904db5d5c'), username='svwctktgmt', first_name='unittest', last_name='user', email='svwctktgmt@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2022, 10, 21, 5, 14, 17), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
#| exporti


@patch(cls_method=True)
def _create(
    cls: APIKey, apikey_to_create: APIKeyCreate, user: User, session: Session
) -> APIKey:
    """Insert a new APIKey owned by `user`

    The new row is added inside `commit_or_rollback`, so it is committed when
    that block exits (presumably rolled back if anything inside it fails).

    Args:
        apikey_to_create: Field values of the new key
        user: Owner of the new key
        session: Sqlmodel session the key is added to

    Returns:
        The newly created APIKey object
    """
    with commit_or_rollback(session):
        new_key = APIKey(**apikey_to_create.dict())
        # Link the owner through the relationship attribute
        new_key.user = user
        session.add(new_key)
    return new_key

In [None]:
#| export


@auth_router.post("/apikey", response_model=Token)
@require_otp_if_mfa_enabled
def create_apikey(
    apikey_to_create: APIKeyCreate,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Token:
    """Create apikey"""
    # NOTE: the docstring above is the endpoint's OpenAPI description, so the
    # implementation notes live in comments instead.
    # Any OTP check for MFA users happens in `require_otp_if_mfa_enabled`.

    # Re-attach the user resolved by the auth dependency to this request's session
    user = session.merge(user)

    # Names only need to be unique among the user's *active* keys: the name of a
    # revoked (disabled) key may be reused.
    key_exists = session.exec(
        select(APIKey)
        .where(APIKey.user == user)
        .where(APIKey.name == apikey_to_create.name)
        .where(APIKey.disabled == False)  # noqa: E712 - `==` builds the SQL expression
    ).first()

    if key_exists is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["APIKEY_NAME_ALREADY_EXISTS"],
        )

    try:
        apikey = APIKey._create(apikey_to_create, user, session)  # type: ignore
        # The token expires together with the key; the "key_uuid" claim lets
        # get_current_active_user look the key up again on every request.
        access_token = create_access_token(
            data={"sub": user.username, "key_uuid": str(apikey.uuid)}, expire=apikey.expiry  # type: ignore
        )
    except Exception as e:
        logger.exception(e)
        # Exceptions exposing `_message()` (presumably SQLAlchemy DB errors) give a
        # cleaner message than str(); fall back to str() for everything else.
        error_message = (
            e._message() if callable(getattr(e, "_message", None)) else str(e)  # type: ignore
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=error_message,
        ) from e

    # SAST recognizes the "bearer" string as a hardcoded password but it is not, hence nosec B106.
    return Token(access_token=access_token, token_type="bearer")  # nosec B106

In [None]:
# MFA enabled user trying to create new api-key
with create_mfa_enabled_user() as user_and_session:
    mfa_user = user_and_session[0]
    mfa_session = user_and_session[1]

    # Negative scenario: the request carries no OTP at all
    with pytest.raises(HTTPException) as e:
        apikey_to_create = APIKeyCreate(
            name=generate_random_name(), expiry=datetime.utcnow() + timedelta(days=1)
        )
        create_apikey(apikey_to_create=apikey_to_create, user=mfa_user, session=mfa_session)
    assert "OTP is required" in str(e.value.detail)
    display(e.value.detail)

    # Negative scenario: a well-formed but wrong OTP
    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        apikey_to_create = APIKeyCreate(
            name=generate_random_name(),
            expiry=datetime.utcnow() + timedelta(days=1),
            otp=random_otp,
        )
        create_apikey(apikey_to_create=apikey_to_create, user=mfa_user, session=mfa_session)
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Positive scenario: the current TOTP code derived from the user's MFA secret
    apikey_to_create = APIKeyCreate(
        name=generate_random_name(),
        expiry=datetime.utcnow() + timedelta(days=1),
        otp=pyotp.TOTP(mfa_user.mfa_secret).now(),
    )
    actual = create_apikey(apikey_to_create=apikey_to_create, user=mfa_user, session=mfa_session)
    display(f'access_token: {"*"*len(actual.access_token)}')
    assert actual.access_token
    assert actual.token_type == "bearer"

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

'access_token: *****************************************************************************************************************************************************************************************************'

In [None]:
# Negative Scenario: Non-MFA user trying to create a new api-key by passing OTP
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        apikey_to_create = APIKeyCreate(
            name=generate_random_name(),
            expiry=datetime.utcnow() + timedelta(days=1),
            otp=random_otp,
        )
        create_apikey(apikey_to_create=apikey_to_create, user=user, session=session)
    # The assertion must sit outside `pytest.raises`: statements after the
    # raising call inside that block never execute.
    assert (
        str(e.value.detail)
        == "MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account."
    )
    display(e.value.detail)

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Positive case: a non-MFA user creates a key without any OTP
    key_name = generate_random_name()
    key_expiry = datetime.utcnow() + timedelta(days=1)
    apikey_to_create = APIKeyCreate(name=key_name, expiry=key_expiry)
    actual = create_apikey(apikey_to_create=apikey_to_create, user=user, session=session)

    display(f'access_token: {"*"*len(actual.access_token)}')
    assert actual.access_token
    assert actual.token_type == "bearer"

    # The issued token must authenticate as the same user
    actual_user = get_current_active_user(actual.access_token)
    display(actual_user)
    assert actual_user.username == user.username

'access_token: *****************************************************************************************************************************************************************************************************'

User(id=11, uuid=UUID('d49d8a24-f5a5-4615-ba80-175904db5d5c'), username='svwctktgmt', first_name='unittest', last_name='user', email='svwctktgmt@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2022, 10, 21, 5, 14, 17), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
# Test for failure case where length of name is too long for apikey

with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Far longer than the DB `name` column allows; create_apikey turns the
    # resulting DB error into a 400
    apikey_to_create = APIKeyCreate(
        name="a" * 1800, expiry=datetime.utcnow() + timedelta(days=1)
    )

    with pytest.raises(HTTPException) as e:
        create_apikey(apikey_to_create=apikey_to_create, user=user, session=session)

    assert e.value.status_code == status.HTTP_400_BAD_REQUEST, e.value
    display(e)

[ERROR] __main__: DataError('(MySQLdb.DataError) (1406, "Data too long for column \'name\' at row 1")')
Traceback (most recent call last):
  File "/root/.local/lib/python3.8/site-packages/sqlalchemy/engine/base.py", line 1900, in _execute_context
    self.dialect.do_execute(
  File "/root/.local/lib/python3.8/site-packages/sqlalchemy/engine/default.py", line 736, in do_execute
    cursor.execute(statement, parameters)
  File "/root/.local/lib/python3.8/site-packages/MySQLdb/cursors.py", line 206, in execute
    res = self._query(query)
  File "/root/.local/lib/python3.8/site-packages/MySQLdb/cursors.py", line 319, in _query
    db.query(q)
  File "/root/.local/lib/python3.8/site-packages/MySQLdb/connections.py", line 254, in query
    _mysql.connection.query(self, query)
MySQLdb.DataError: (1406, "Data too long for column 'name' at row 1")

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<ipython-input-54-6168e5ae3cd7>", 

<ExceptionInfo HTTPException(status_code=400, detail='(MySQLdb.DataError) (1406, "Data too long for column \'name\' at row 1")') tblen=3>

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # A key that expires in 5 seconds: its token must be rejected once that passes
    apikey_to_create = APIKeyCreate(
        name=generate_random_name(), expiry=datetime.utcnow() + timedelta(seconds=5)
    )
    actual = create_apikey(
        apikey_to_create=apikey_to_create, user=user, session=session
    )
    display(f'access_token: {"*"*len(actual.access_token)}')

    # Sleep comfortably past the 5 second expiry
    time.sleep(10)
    with pytest.raises(HTTPException) as e:
        get_current_active_user(actual.access_token)
    assert e.value.status_code == status.HTTP_401_UNAUTHORIZED, e.value
    display(e)

'access_token: *****************************************************************************************************************************************************************************************************'

<ExceptionInfo HTTPException(status_code=401, detail='Your credentials could not be validated. The developer token/apikey is invalid or expired.') tblen=2>

In [None]:
# Test for failure case where an api-key is created using an existing active api-key's name

with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # First key with a fresh random name is created successfully
    random_api_key_name = generate_random_name()
    apikey_to_create = APIKeyCreate(
        name=random_api_key_name, expiry=datetime.utcnow() + timedelta(days=1)
    )
    actual = create_apikey(
        apikey_to_create=apikey_to_create, user=user, session=session
    )

    display(f'access_token: {"*"*len(actual.access_token)}')
    assert actual.access_token
    assert actual.token_type == "bearer"

    # A second key reusing that (still active) name must be rejected
    with pytest.raises(HTTPException) as e:
        apikey_to_create = APIKeyCreate(
            name=random_api_key_name, expiry=datetime.utcnow() + timedelta(days=1)
        )
        create_apikey(apikey_to_create=apikey_to_create, user=user, session=session)

    assert (
        e.value.detail == "An Api-key with the same name already exists"
    ), e.value.detail

'access_token: *****************************************************************************************************************************************************************************************************'

In [None]:
#| export


@auth_router.get(
    "/apikey/{key_uuid_or_name}", response_model=APIKeyRead, responses=get_apikey_responses  # type: ignore
)
def get_details_of_apikey(
    key_uuid_or_name: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> APIKey:
    """Get details of the apikey"""
    # The docstring above doubles as the OpenAPI description of this endpoint.
    # Attach the user from the auth dependency to this request's session, then
    # resolve the key (by uuid or name) among that user's keys.
    owner = session.merge(user)
    found_key = APIKey.get(key_uuid_or_name=key_uuid_or_name, user=owner, session=session)  # type: ignore
    return found_key

In [None]:
# MFA enabled user trying to get details of the api-key
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    # Create a new api-key for testing (an MFA user needs a valid OTP to create one)
    apikey_to_create = APIKeyCreate(
        name=generate_random_name(),
        expiry=datetime.utcnow() + timedelta(days=1),
        otp=pyotp.TOTP(user.mfa_secret).now(),
    )
    create_apikey(
        apikey_to_create=apikey_to_create,
        user=user,
        session=session,
    )

    # Positive scenario: reading key details requires no OTP, even for an MFA user
    key_details = get_details_of_apikey(
        key_uuid_or_name=apikey_to_create.name, user=user, session=session
    )
    display(key_details.name)
    assert key_details.name == apikey_to_create.name

'CTFIZX5ZVCJ70AK'

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Fetch the user's most recent apikey through the endpoint, addressing it by uuid
    expected = user.apikeys[-1]
    expected_key_id = str(expected.uuid)
    actual = get_details_of_apikey(
        key_uuid_or_name=expected_key_id, user=user, session=session
    )
    display(actual)
    assert actual == expected

APIKey(name='J41CKQH992MA59Z', expiry=datetime.datetime(2022, 10, 22, 5, 14, 36), uuid=UUID('536b91b9-9f6d-4c3c-b531-ffa8795ef84f'), created=datetime.datetime(2022, 10, 21, 5, 14, 36), id=12, disabled=False, user_id=11)

In [None]:
#| export


def get_valid_user(user: User, session: Session, user_uuid_or_name: str) -> User:
    """Get valid user object to perform the operation

    A value that parses as a UUID (or already is a `uuid.UUID`) is matched
    against `User.uuid`; anything else is treated as a username.

    Args:
        user: The requesting user
        session: Sqlmodel session
        user_uuid_or_name: Account user_uuid/username to perform the operation

    Returns:
        The User object the operation should be performed on

    Raises:
        HTTPException: 401 if a non-super user targets another user's account,
            400 if no user matches the given uuid/username
    """
    identifier: Union[str, uuid.UUID]
    if isinstance(user_uuid_or_name, uuid.UUID):
        # Accept UUID objects as-is: uuid.UUID(<UUID>) raises AttributeError
        identifier, attribute_to_check = user_uuid_or_name, "uuid"
    else:
        try:
            identifier, attribute_to_check = uuid.UUID(user_uuid_or_name), "uuid"
        except ValueError:
            identifier, attribute_to_check = user_uuid_or_name, "username"

    # Only super users may operate on accounts other than their own
    if getattr(user, attribute_to_check) != identifier and not user.super_user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERRORS["NOT_ENOUGH_PERMISSION_TO_ACCESS_OTHERS_DATA"],
        )

    try:
        _user = session.exec(
            select(User).where(getattr(User, attribute_to_check) == identifier)
        ).one()

    except NoResultFound:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS[f"INCORRECT_USER_{attribute_to_check.upper()}"],
        )

    return _user

In [None]:
with get_session_with_context() as session:
    user = session.exec(
        select(User).where(User.username == create_user_for_testing())
    ).one()

    # Negative Scenario: Normal user accessing other users account
    with pytest.raises(HTTPException) as e:
        get_valid_user(user, session, INVALID_UUID_FOR_TESTING)
    assert e.value.detail == "Insufficient permission to access other user's data"
    display(e.value.detail)

    with pytest.raises(HTTPException) as e:
        random_user_name = "random_user_name"
        get_valid_user(user, session, random_user_name)
    assert e.value.detail == "Insufficient permission to access other user's data"
    display(e.value.detail)

    # Positive Scenario: Normal user accessing own account
    actual = get_valid_user(user, session, str(user.uuid))
    display(actual)
    actual = get_valid_user(user, session, user.username)
    display(actual)

# Positive Scenario: Super user accessing own account
with get_session_with_context() as session:
    super_user = session.exec(select(User).where(User.username == "kumaran")).one()
    actual = get_valid_user(super_user, session, str(super_user.uuid))
    assert actual.id == super_user.id

    actual = get_valid_user(super_user, session, super_user.username)
    assert actual.uuid == super_user.uuid

    # Positive Scenario: Super user accessing others account
    # (`user` is the normal user fetched in the first `with` block of this cell)
    actual = get_valid_user(super_user, session, str(user.uuid))
    assert actual.uuid == user.uuid

    actual = get_valid_user(super_user, session, user.username)
    assert actual.uuid == user.uuid

    # Negative Scenario: Super user accessing invalid user account
    with pytest.raises(HTTPException) as e:
        get_valid_user(super_user, session, INVALID_UUID_FOR_TESTING)
    assert "The user uuid is incorrect" in e.value.detail
    display(e.value.detail)

    with pytest.raises(HTTPException) as e:
        random_user_name = "random_user_name"
        get_valid_user(super_user, session, random_user_name)
    assert "Incorrect username" in e.value.detail
    display(e.value.detail)

"Insufficient permission to access other user's data"

"Insufficient permission to access other user's data"

User(id=73, uuid=UUID('7edc2abc-e4d3-4c13-b1e9-614a58f25b4a'), username='fhanlxanat', first_name='unittest', last_name='user', email='fhanlxanat@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2022, 10, 21, 5, 14, 37), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

User(id=73, uuid=UUID('7edc2abc-e4d3-4c13-b1e9-614a58f25b4a'), username='fhanlxanat', first_name='unittest', last_name='user', email='fhanlxanat@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2022, 10, 21, 5, 14, 37), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

'The user uuid is incorrect. Please try again.'

'Incorrect username. Please try again.'

In [None]:
#| exporti


@patch
def disable(self: APIKey, session: Session) -> APIKey:
    """Revoke this APIKey by flagging it as disabled and persisting the flag

    Args:
        session: Sqlmodel session used to persist the change

    Returns:
        This same APIKey object, now disabled
    """
    # The change is written through commit_or_rollback when the block exits
    with commit_or_rollback(session):
        self.disabled = True
        session.add(self)
    revoked_key = self
    return revoked_key

In [None]:
#| export


@auth_router.delete(
    "/{user_uuid_or_name}/apikey/{key_uuid_or_name}", response_model=APIKeyRead, responses=get_apikey_responses  # type: ignore
)
@require_otp_if_mfa_enabled
def delete_apikey(
    user_uuid_or_name: str,
    key_uuid_or_name: str,
    otp: Optional[str] = None,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> APIKey:
    """Revoke apikey"""
    # The docstring above doubles as the OpenAPI description of this endpoint.
    # `otp` is never read in this body; the OTP check is done by the
    # require_otp_if_mfa_enabled decorator.
    requester = session.merge(user)
    # Resolve whose key is being revoked; non-super users can only target themselves
    key_owner = get_valid_user(requester, session, user_uuid_or_name)
    apikey_to_revoke = APIKey.get(key_uuid_or_name=key_uuid_or_name, user=key_owner, session=session)  # type: ignore
    return apikey_to_revoke.disable(session)  # type: ignore

In [None]:
# MFA enabled user trying to revoke the api-key, identifying the owner by uuid
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    # Creating new API-key
    random_api_key_name = generate_random_name()
    apikey_to_create = APIKeyCreate(
        name=random_api_key_name,
        expiry=datetime.utcnow() + timedelta(days=1),
        otp=pyotp.TOTP(user.mfa_secret).now(),
    )
    original_api_key = create_apikey(
        apikey_to_create=apikey_to_create, user=user, session=session
    )
    assert original_api_key.access_token
    assert original_api_key.token_type == "bearer"

    # Negative scenario: Not passing the OTP in the request
    with pytest.raises(HTTPException) as e:
        delete_apikey(
            user_uuid_or_name=str(user.uuid),
            key_uuid_or_name=user.apikeys[-1].name,
            otp=None,
            user=user,
            session=session,
        )
    display(e.value.detail)
    assert "OTP is required" in str(e.value.detail)

    # Negative scenario: passing invalid OTP in the request
    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        delete_apikey(
            user_uuid_or_name=str(user.uuid),
            key_uuid_or_name=user.apikeys[-1].name,
            otp=random_otp,
            user=user,
            session=session,
        )
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Positive scenario: passing valid OTP and user UUID in the request
    valid_otp = pyotp.TOTP(user.mfa_secret).now()
    key_details = delete_apikey(
        user_uuid_or_name=str(user.uuid),
        key_uuid_or_name=user.apikeys[-1].name,
        otp=valid_otp,
        user=user,
        session=session,
    )
    display(key_details.name)
    assert key_details.name is not None

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

'B31YULCZJAIFPFW'

In [None]:
# MFA enabled user trying to revoke the api-key, identifying the owner by username
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    # Creating new API-key
    random_api_key_name = generate_random_name()
    apikey_to_create = APIKeyCreate(
        name=random_api_key_name,
        expiry=datetime.utcnow() + timedelta(days=1),
        otp=pyotp.TOTP(user.mfa_secret).now(),
    )
    original_api_key = create_apikey(
        apikey_to_create=apikey_to_create, user=user, session=session
    )
    assert original_api_key.access_token
    assert original_api_key.token_type == "bearer"

    # Positive scenario: passing valid OTP and username in the request
    # NOTE(review): assumes the freshly created MFA user has only the key created
    # above, so `apikeys[-1]` is that key — confirm against create_mfa_enabled_user
    valid_otp = pyotp.TOTP(user.mfa_secret).now()
    key_details = delete_apikey(
        user_uuid_or_name=user.username,
        key_uuid_or_name=user.apikeys[-1].name,
        otp=valid_otp,
        user=user,
        session=session,
    )
    display(key_details.name)
    assert key_details.name is not None

'301XIR2Y5IE861O'

In [None]:
# Negative Scenario: Non-MFA user trying to revoke the api-key by passing OTP
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Supplying an OTP is itself an error when MFA is not active for the account
    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        delete_apikey(
            user_uuid_or_name=str(user.uuid),
            key_uuid_or_name=user.apikeys[-1].name,
            otp=random_otp,
            user=user,
            session=session,
        )
    display(e.value.detail)
    assert (
        str(e.value.detail)
        == "MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account."
    )

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # An unnamed key inserted directly into the DB, then revoked by its uuid
    expected = APIKey(user=user)
    session.add(expected)
    session.commit()
    session.refresh(expected)
    display(f"{expected=}")

    actual = delete_apikey(
        user_uuid_or_name=str(user.uuid),
        key_uuid_or_name=str(expected.uuid),
        otp=None,
        user=user,
        session=session,
    )
    # NOTE(review): this renders as `APIKey()` in the saved output, presumably
    # because the attributes are expired by the commit inside `disable`; the
    # attribute accesses in the asserts below load them again
    display(f"{actual=}")
    assert actual.id == expected.id
    assert actual.disabled

"expected=APIKey(name=None, expiry=None, uuid=UUID('19251603-ede6-473c-9343-34983e76243e'), created=datetime.datetime(2022, 10, 21, 5, 14, 38), id=16, disabled=False, user_id=11)"

'actual=APIKey()'

In [None]:
# Positive case: Create a new api-key using revoked api-key's name

with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Creating new API-key
    random_api_key_name = generate_random_name()
    apikey_to_create = APIKeyCreate(
        name=random_api_key_name, expiry=datetime.utcnow() + timedelta(days=1), otp=None
    )
    original_api_key = create_apikey(
        apikey_to_create=apikey_to_create, user=user, session=session
    )

    display(f'access_token: {"*"*len(original_api_key.access_token)}')
    assert original_api_key.access_token
    assert original_api_key.token_type == "bearer"

with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    assert random_api_key_name == user.apikeys[-1].name

    # Revoking the created API-key
    revoked_api_key = delete_apikey(
        user_uuid_or_name=str(user.uuid),
        key_uuid_or_name=user.apikeys[-1].name,
        otp=None,
        user=user,
        session=session,
    )
    assert revoked_api_key.id == user.apikeys[-1].id
    assert revoked_api_key.disabled

    # Creating a new API-key with the revoked key's name: allowed, since name
    # uniqueness is only enforced among active keys
    new_api_key = APIKeyCreate(
        name=random_api_key_name,
        expiry=datetime.utcnow() + timedelta(days=1),
        otp=None,
    )
    new_key = create_apikey(
        apikey_to_create=new_api_key, user=user, session=session
    )

    display(f'access_token: {"*"*len(new_key.access_token)}')
    assert new_key.access_token
    assert new_key.token_type == "bearer"

'access_token: *****************************************************************************************************************************************************************************************************'

'access_token: *****************************************************************************************************************************************************************************************************'

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Creating new API-key
    random_api_key_name = generate_random_name()
    apikey_to_create = APIKeyCreate(
        name=random_api_key_name, expiry=datetime.utcnow() + timedelta(days=1), otp=None
    )
    original_api_key = create_apikey(
        apikey_to_create=apikey_to_create, user=user, session=session
    )

    # Super user Trying to revoke other user api-key
    # NOTE(review): assumes `user.apikeys[-1]` is the key created just above,
    # i.e. the relationship lists keys in creation order — confirm
    super_user = session.exec(select(User).where(User.username == "kumaran")).one()
    revoked_api_key = delete_apikey(
        user_uuid_or_name=str(user.uuid),
        key_uuid_or_name=str(user.apikeys[-1].uuid),
        otp=None,
        user=super_user,
        session=session,
    )
    assert revoked_api_key.id == user.apikeys[-1].id
    assert revoked_api_key.disabled
    display(revoked_api_key)

APIKey(created=datetime.datetime(2022, 10, 21, 5, 14, 38), name='0ET9ZYZWEOE5NUI', uuid=UUID('a4eab247-783e-43f1-859e-a0c3964e78f0'), expiry=datetime.datetime(2022, 10, 22, 5, 14, 38), id=19, user_id=11, disabled=True)

In [None]:
#| exporti


@patch(cls_method=True)
def get_all(
    cls: APIKey,
    user: User,
    include_disabled: bool,
    offset: int,
    limit: int,
    session: Session,
) -> List[APIKey]:
    """Get all apikeys of a user, one page at a time

    Args:
        user: User object whose apikeys are listed
        include_disabled: Whether to include disabled (revoked) apikeys
        offset: Number of apikeys to skip
        limit: Maximum number of apikeys to return
        session: Sqlmodel session

    Returns:
        A list of apikey objects, ordered by id
    """
    statement = select(APIKey).where(APIKey.user == user)
    if not include_disabled:
        # `== False` (not `is False`) is needed so SQLAlchemy builds a SQL expression
        statement = statement.where(APIKey.disabled == False)  # noqa: E712
    # OFFSET/LIMIT without ORDER BY has no guaranteed row order in SQL, so pages
    # could overlap or skip keys; order by primary key for deterministic paging.
    statement = statement.order_by(APIKey.id)  # type: ignore
    return session.exec(statement.offset(offset).limit(limit)).all()

In [None]:
#| export


@auth_router.get(
    "/{user_uuid_or_name}/apikey",
    response_model=List[APIKeyRead],
    responses={
        400: {
            "model": HTTPError,
            "description": ERRORS["INCORRECT_USER_UUID"],
        },
        401: {
            "model": HTTPError,
            "description": ERRORS["NOT_ENOUGH_PERMISSION_TO_ACCESS_OTHERS_DATA"],
        },
    },
)
def get_all_apikey(
    user_uuid_or_name: str,
    include_disabled: bool = False,
    offset: int = 0,
    # `le` is FastAPI's "less than or equal" validator. The previous `lte`
    # keyword is not a recognised constraint, so the 100 cap was never enforced.
    limit: int = Query(default=100, le=100),
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> List[APIKey]:
    """Get all apikeys created by user"""
    # The docstring above doubles as the OpenAPI description of this endpoint.
    # Non-super users may only list their own keys (enforced by get_valid_user).
    user = session.merge(user)
    return APIKey.get_all(  # type: ignore
        user=get_valid_user(user, session, user_uuid_or_name),
        include_disabled=include_disabled,
        offset=offset,
        limit=limit,
        session=session,
    )

In [None]:
# MFA enabled user listing all their api-keys (listing needs no OTP)
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    # Creating new API-key
    random_api_key_name = generate_random_name()
    apikey_to_create = APIKeyCreate(
        name=random_api_key_name,
        expiry=datetime.utcnow() + timedelta(days=1),
        otp=pyotp.TOTP(user.mfa_secret).now(),
    )
    original_api_key = create_apikey(
        apikey_to_create=apikey_to_create, user=user, session=session
    )
    assert original_api_key.access_token
    assert original_api_key.token_type == "bearer"

    # Positive scenario: passing valid user UUID in the request
    # (limit=1, so the single result must be the user's first key)
    actual = get_all_apikey(
        user_uuid_or_name=str(user.uuid), offset=0, limit=1, user=user, session=session
    )
    assert len(actual) == 1
    assert isinstance(actual[0], APIKey)
    assert actual[0] == user.apikeys[0]

    # Positive scenario: passing valid username in the request
    actual = get_all_apikey(
        user_uuid_or_name=user.username, offset=0, limit=1, user=user, session=session
    )
    assert len(actual) == 1
    assert isinstance(actual[0], APIKey)
    assert actual[0] == user.apikeys[0]

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # A one-element page starting at offset 0 must hold exactly the user's first key
    actual = get_all_apikey(
        user_uuid_or_name=str(user.uuid),
        offset=0,
        limit=1,
        user=user,
        session=session,
    )
    display(actual)

    assert len(actual) == 1
    first_key = actual[0]
    assert isinstance(first_key, APIKey)
    assert first_key == user.apikeys[0]

[APIKey(name=None, expiry=None, uuid=UUID('4827718d-592a-4aad-91cf-bea09ea0d5d0'), created=datetime.datetime(2022, 10, 21, 5, 14, 26), id=1, disabled=False, user_id=11)]

In [None]:
# Super user trying to access other users API keys
# (`user` is the test user fetched in the previous cell)
with get_session_with_context() as session:
    super_user = session.exec(select(User).where(User.username == "kumaran")).one()

    actual = get_all_apikey(
        user_uuid_or_name=str(user.uuid), offset=0, limit=1, user=super_user, session=session
    )
    display(actual)
    # The listed key belongs to the target user, not to the requesting super user
    assert actual[0].user_id != super_user.id
    assert actual[0].user_id == user.id

[APIKey(name=None, expiry=None, uuid=UUID('4827718d-592a-4aad-91cf-bea09ea0d5d0'), created=datetime.datetime(2022, 10, 21, 5, 14, 26), id=1, disabled=False, user_id=11)]

In [None]:
# `include_disabled` toggles whether revoked keys are listed. Use a fresh session
# rather than reusing one whose context manager has already exited in an earlier cell.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    actual = get_all_apikey(
        user_uuid_or_name=str(user.uuid),
        include_disabled=False,
        offset=0,
        limit=10,
        user=user,
        session=session,
    )
    display(f"{len(actual)=}")
    assert not any(apikey.disabled for apikey in actual)

    actual = get_all_apikey(
        user_uuid_or_name=str(user.uuid),
        include_disabled=True,
        offset=0,
        limit=10,
        user=user,
        session=session,
    )
    display(f"{len(actual)=}")
    # Earlier cells revoked some of this user's keys, so at least one must show up
    assert any(apikey.disabled for apikey in actual)

'len(actual)=7'

'len(actual)=10'