In [None]:
# | default_exp users

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.


2023-03-10 07:53:05.529943: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.


In [None]:
# | export

import functools
import random
import re
import secrets
import string
import uuid
from typing import *

from airt.logger import get_logger
from airt.patching import patch
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, EmailStr, validator
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlmodel import Session, select

import airt_service
import airt_service.sanitizer
from airt_service.auth import get_current_active_user, get_user, get_valid_user
from airt_service.cleanup import cleanup_user
from airt_service.db.models import (
    SMS,
    SSO,
    SMSProtocol,
    SSOBase,
    SSOProvider,
    SSORead,
    User,
    UserCreate,
    UserRead,
    get_session,
    get_session_with_context,
)
from airt_service.errors import ERRORS, HTTPError
from airt_service.helpers import commit_or_rollback, get_attr_by_name, get_password_hash
from airt_service.sms_utils import (
    get_app_and_message_id,
    get_application_and_message_config,
    send_sms,
    validate_otp,
    verify_pin,
)
from airt_service.sso import initiate_sso_flow
from airt_service.totp import (
    generate_mfa_provisioning_url,
    generate_mfa_secret,
    require_otp_if_mfa_enabled,
    validate_totp,
)

In [None]:
import json
import urllib
from contextlib import contextmanager
from datetime import datetime, timedelta
from os import environ

import pandas as pd
import pyotp
import pytest
import requests
from _pytest.monkeypatch import MonkeyPatch
from airt.remote_path import RemotePath
from fastapi import BackgroundTasks, Request
from pydantic import ValidationError
from sqlalchemy.exc import NoResultFound
from starlette.datastructures import Headers

from airt_service.auth import create_apikey
from airt_service.aws.utils import upload_to_s3_with_retry
from airt_service.constants import MFA_ISSUER_NAME
from airt_service.data.csv import process_csv
from airt_service.data.datablob import FromLocalRequest, from_local_start_route
from airt_service.db.models import (
    APIKeyCreate,
    DataBlob,
    DataSource,
    create_user_for_testing,
)
from airt_service.helpers import set_env_variable_context, verify_password
from airt_service.model.train import TrainRequest, predict_model, train_model

In [None]:
# | exporti

# Module-level logger, named after this module
logger = get_logger(__name__)

In [None]:
# Create a throwaway user for the tests below; the helper returns the new username
test_username = create_user_for_testing()
display(test_username)

'kxzpllgnsa'

In [None]:
# Nil UUID passed as a user uuid in the negative test scenarios, which expect it to be rejected
INVALID_UUID_FOR_TESTING = "00000000-0000-0000-0000-000000000000"

In [None]:
# | exporti

# Returned by the /user/send_sms_otp route in every case, whether or not an SMS was sent
SEND_SMS_OTP_MSG = "If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator."
# Confirmation text for a password reset; not referenced in this part of the notebook
# (presumably used by the reset password route - confirm)
PASSWORD_RESET_MSG = "Password reset successful"  # nosec B105

In [None]:
# | export

# Default router for users: every route registered on it is served under the /user prefix
# and inherits the 404 and 500 response documentation declared here.
user_router = APIRouter(
    prefix="/user",
    tags=["user"],
    # Authentication is not enforced router-wide: routes that need it declare
    # Depends(get_current_active_user) themselves, while send_sms_otp does not.
    #     dependencies=[Depends(get_current_active_user)],
    responses={
        404: {"description": "Not found"},
        500: {
            "model": HTTPError,
            "description": ERRORS["INTERNAL_SERVER_ERROR"],
        },
    },
)

In [None]:
# | export


def ensure_super_user(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator to ensure the user who executes the operation is a super user

    The decorated callable must accept an argument named ``user`` whose value
    exposes a ``super_user`` attribute. The user is normally passed as a keyword
    argument; when it is passed positionally it is resolved through the signature
    of the decorated callable instead of failing with a ``KeyError``.

    Args:
        func: The callable to protect

    Returns:
        A wrapper which checks the permission and then delegates to ``func``

    Raises:
        HTTPException: With status 401 if the user is not a super user
        KeyError: If no ``user`` argument was passed to the decorated callable
        TypeError: If ``user`` is not passed by keyword and the positional
            arguments do not fit the signature of the decorated callable
    """
    # Local import: only the positional-argument fallback below needs it
    import inspect

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if "user" in kwargs:
            user = kwargs["user"]
        else:
            # Map the positional arguments onto the parameter names of func; a
            # call without any user still ends in KeyError, as it did before.
            bound = inspect.signature(func).bind_partial(*args, **kwargs)
            user = bound.arguments["user"]

        if not user.super_user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERRORS["NOT_ENOUGH_PERMISSION"],
            )
        return func(*args, **kwargs)

    return wrapper

In [None]:
# Positive scenario: a super user passes the decorator and the wrapped function runs
with get_session_with_context() as session:
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()

    @ensure_super_user
    def test_func(user):
        display(user)

    test_func(user=user_kumaran)

# Negative scenario: a user without super user rights must be rejected with an HTTPException
with get_session_with_context() as session:
    not_super_user = session.exec(
        select(User).where(User.username == test_username)
    ).one()
    with pytest.raises(HTTPException) as e:
        test_func(user=not_super_user)
display(e)

User(id=3, uuid=UUID('5ffb486f-da50-4816-bb77-0ab63f9a0d4d'), username='kumaran', first_name='Kumaran', last_name='Rajendhiran', email='kumaran@airt.ai', subscription_type=<SubscriptionType.superuser: 'superuser'>, super_user=True, disabled=False, created=datetime.datetime(2023, 3, 9, 8, 17, 31), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

<ExceptionInfo HTTPException(status_code=401, detail='You do not have sufficient permission to access this route. Please contact your administrator for help.') tblen=2>

In [None]:
# | export


class GenerateMFARresponse(BaseModel):
    """Response model of the /user/mfa/generate route

    Args:
        mfa_url: The provisioning url generated from the MFA secret
    """

    mfa_url: str

In [None]:
# | export


@user_router.get("/mfa/generate", response_model=GenerateMFARresponse)
@require_otp_if_mfa_enabled
def generate_mfa_url(
    otp: Optional[str] = None,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> GenerateMFARresponse:
    """Generate a new MFA secret for the current user and return its provisioning url"""
    # otp is not read here; presumably the require_otp_if_mfa_enabled decorator
    # consumes it - confirm against airt_service.totp.
    db_user = session.merge(user)

    new_secret = generate_mfa_secret()
    provisioning_url = generate_mfa_provisioning_url(
        mfa_secret=new_secret, user_email=db_user.email
    )

    # Store the secret on the user. This route does not touch is_mfa_active;
    # activation happens in the /user/mfa/activate route.
    with commit_or_rollback(session):
        db_user.mfa_secret = new_secret
        session.add(db_user)

    return GenerateMFARresponse(mfa_url=provisioning_url)

In [None]:
# Generate the MFA url for the test user; the url itself is masked in the displayed output
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    actual = generate_mfa_url(user=user, session=session)
    display(f"mfa_url={'*'*len(actual.mfa_url)}")

    # The url must carry the url-quoted email of the user and the issuer name
    assert len(actual.mfa_url)
    assert urllib.parse.quote(user.email) in actual.mfa_url
    assert MFA_ISSUER_NAME in actual.mfa_url

# Re-read the user in a new session: the secret is stored but MFA is not active yet
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    assert len(user.mfa_secret) == 32
    assert not user.is_mfa_active
    # NOTE(review): bare expression inside the with block - it has no effect and displays nothing
    user.is_mfa_active

'mfa_url=****************************************************************************************************'

In [None]:
# | export


class ActivateMFARequest(BaseModel):
    """Request body of the /user/mfa/activate route

    Args:
        user_otp: OTP passed by the user
    """

    user_otp: str

In [None]:
# | export


@user_router.post("/mfa/activate", response_model=UserRead)
def activate_mfa(
    activate_mfa_request: ActivateMFARequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> User:
    """Activate MFA for the current user after validating the OTP passed by the user"""
    db_user = session.merge(user)
    submitted_otp = activate_mfa_request.user_otp

    # Without a stored secret there is nothing to validate the OTP against
    secret_missing = not db_user.mfa_secret
    if secret_missing:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["GENERATE_MFA_URL_NOT_GENERATED"],
        )

    # An exception raised by validate_totp propagates, leaving MFA inactive
    validate_totp(db_user.mfa_secret, submitted_otp)

    with commit_or_rollback(session):
        db_user.is_mfa_active = True
        session.add(db_user)

    return db_user

In [None]:
# Positive scenario: activate MFA with a TOTP computed from the stored secret.
# mfa_active_user is reused by later test cells.
with get_session_with_context() as session:
    mfa_active_user = session.exec(
        select(User).where(User.username == test_username)
    ).one()

    activate_mfa_request = ActivateMFARequest(
        user_otp=pyotp.TOTP(mfa_active_user.mfa_secret).now()
    )

    actual = activate_mfa(
        activate_mfa_request=activate_mfa_request, user=mfa_active_user, session=session
    )
    display(actual)
    assert actual.is_mfa_active

    # Negative scenario: an arbitrary OTP must be rejected
    activate_mfa_request = ActivateMFARequest(user_otp="123123")
    with pytest.raises(HTTPException) as e:
        activate_mfa(
            activate_mfa_request=activate_mfa_request,
            user=mfa_active_user,
            session=session,
        )

    display(e.value.detail)
    assert "Invalid OTP" in e.value.detail

# Negative scenario: calling the activate route without calling the generate route first
random_username = create_user_for_testing()
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == random_username)).one()

    activate_mfa_request = ActivateMFARequest(user_otp="123123")
    with pytest.raises(HTTPException) as e:
        activate_mfa(
            activate_mfa_request=activate_mfa_request, user=user, session=session
        )

e.value

User(id=410, uuid=UUID('487479d2-a27f-46ae-939e-3f896dc77b2c'), username='kxzpllgnsa', first_name='unittest', last_name='user', email='kxzpllgnsa@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 9), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

'Invalid OTP. Please try again.'

HTTPException(status_code=400, detail='MFA code is not generated for this user, please call /user/mfa/generate first')

In [None]:
# | export


def get_user_to_disable_mfa(user: User, session: Session, user_uuid: str) -> User:
    """Resolve the user whose MFA should be disabled

    The lookup is delegated to get_valid_user, which is expected to enforce that
    only a super user can act on other users in the server.

    Args:
        user: User object performing the operation
        session: Sqlmodel session
        user_uuid: Identifier of the user to disable MFA for, passed through to
            get_valid_user unchanged

    Returns:
        User object to disable MFA

    Raises:
        HTTPException: With status 400 if MFA is not active for the resolved user
    """
    target_user = get_valid_user(user, session, user_uuid)

    if target_user.is_mfa_active:
        return target_user

    # Nothing to disable for this user
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=ERRORS["MFA_ALREADY_DISABLED"],
    )

In [None]:
# Relies on mfa_active_user, the MFA enabled user created by the activate_mfa test cell above
with get_session_with_context() as session:
    user = session.exec(
        select(User).where(User.username == mfa_active_user.username)
    ).one()

    # Negative Scenario: Normal user disabling MFA for others
    with pytest.raises(HTTPException) as e:
        get_user_to_disable_mfa(user, session, INVALID_UUID_FOR_TESTING)

    # Positive Scenario: MFA enabled Normal user disabling MFA for self
    actual = get_user_to_disable_mfa(user, session, str(user.uuid))
    assert actual.id == user.id
    display(actual)

# Negative Scenario: MFA disabled Normal User disabling MFA for self
test_username = create_user_for_testing()
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    with pytest.raises(HTTPException) as e:
        get_user_to_disable_mfa(user, session, str(user.uuid))

User(id=410, uuid=UUID('487479d2-a27f-46ae-939e-3f896dc77b2c'), username='kxzpllgnsa', first_name='unittest', last_name='user', email='kxzpllgnsa@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 9), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

In [None]:
# A freshly created user with MFA disabled, used as the "other user" below
test_username = create_user_for_testing()
with get_session_with_context() as session:
    other_user = session.exec(select(User).where(User.username == test_username)).one()

with get_session_with_context() as session:
    super_user = session.exec(select(User).where(User.username == "kumaran")).one()

    # Negative Scenario: MFA disabled Super user disabling MFA for self
    with pytest.raises(HTTPException) as e:
        get_user_to_disable_mfa(super_user, session, str(super_user.uuid))

    # Negative Scenario: MFA disabled Super user disabling MFA for invalid user id
    with pytest.raises(HTTPException) as e:
        get_user_to_disable_mfa(super_user, session, INVALID_UUID_FOR_TESTING)

    # Negative Scenario: MFA disabled Super user disabling MFA for another MFA disabled user
    with pytest.raises(HTTPException) as e:
        get_user_to_disable_mfa(super_user, session, str(other_user.uuid))

    # Positive Scenario: MFA disabled Super user disabling MFA for MFA enabled other user
    actual = get_user_to_disable_mfa(super_user, session, str(mfa_active_user.uuid))
    assert not actual.id == super_user.id
    assert actual.id == mfa_active_user.id

    # Positive Scenario: MFA enabled Super user disabling MFA self
    # NOTE(review): the flag is set on the loaded object without an explicit commit here;
    # confirm whether leaving the session context persists it for the shared super user.
    super_user.is_mfa_active = True
    actual = get_user_to_disable_mfa(super_user, session, str(super_user.uuid))
    assert actual.id == super_user.id
    display(actual)

User(id=3, uuid=UUID('5ffb486f-da50-4816-bb77-0ab63f9a0d4d'), username='kumaran', first_name='Kumaran', last_name='Rajendhiran', email='kumaran@airt.ai', subscription_type=<SubscriptionType.superuser: 'superuser'>, super_user=True, disabled=False, created=datetime.datetime(2023, 3, 9, 8, 17, 31), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=True)

In [None]:
# | exporti


def create_sms_protocol(xs: Dict[str, str], sms: SMS, session: Session) -> None:
    """Store the outcome of a send SMS call as a new sms protocol record

    Args:
        xs: The response from infobip's send sms API; the keys "pinId", "ncStatus",
            "smsStatus" and "to" are read from it
        sms: Instance of the SMS db model the new record gets linked to
        session: Session object
    """

    with commit_or_rollback(session):
        # Read the fields inside the commit_or_rollback context so that a missing
        # key is raised within it
        pin_id, lookup_status, sent_status, recipient = (
            xs["pinId"],
            xs["ncStatus"],
            xs["smsStatus"],
            xs["to"],
        )
        protocol_record = SMSProtocol(
            pin_id=pin_id,
            number_lookup_status=lookup_status,
            sent_sms_status=sent_status,
            phone_number=recipient,
        )
        protocol_record.sms = sms
        session.add(protocol_record)

In [None]:
# NOTE(review): this session is never used - the name is rebound by the with block below
session = next(get_session())
# Sample response in the shape create_sms_protocol reads (pinId, to, ncStatus, smsStatus)
sample_response = {
    "pinId": "my_random_pin_id",
    "to": "910000000000",
    "ncStatus": "NC_NOT_CONFIGURED",
    "smsStatus": "MESSAGE_SENT",
}
test_username = create_user_for_testing()
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    # Create the parent SMS row the protocol record gets attached to
    with commit_or_rollback(session):
        sms = SMS(application_id="000000", message_id="000000")
        sms.user = user
        session.add(sms)

    create_sms_protocol(sample_response, sms, session)

    # Exactly one protocol row must now exist for the SMS row
    sms_protocol = session.exec(
        select(SMSProtocol).where(SMSProtocol.sms_id == sms.id)
    ).one()

    display(sms_protocol)
    assert sms_protocol.sms_id == sms.id
    assert sms_protocol.pin_id == sample_response["pinId"]
    assert sms_protocol.phone_number == sample_response["to"]

SMSProtocol(id=30, number_lookup_status='NC_NOT_CONFIGURED', phone_number='910000000000', pin_attempts_remaining=None, pin_id='my_random_pin_id', sent_sms_status='MESSAGE_SENT', pin_verified=False, sms_id=30)

In [None]:
# | exporti


def _get_allowed_message_template_names() -> List[str]:
    """Get valid message templates for sending SMS

    The **register_phone_number** template is excluded from the allowed list because
    while registering the phone number the users should use the /register_phone_number
    route for sending the SMS and use /validate_phone_number route to validate.

    The template is filtered out rather than removed, so a configuration that does
    not define it no longer raises a ValueError.

    Returns:
        The list of valid message template names, in the order of the configuration
    """
    configured_names = get_application_and_message_config()["message_config"].keys()

    return [name for name in configured_names if name != "register_phone_number"]

In [None]:
# register_phone_number must not be part of the allowed names; the comparison is order-sensitive
expected = ["reset_password", "disable_mfa", "get_token"]
actual = _get_allowed_message_template_names()

display(actual)
assert actual == expected

['reset_password', 'disable_mfa', 'get_token']

In [None]:
# | exporti


def _send_sms_otp_to_user(
    user: User,
    message_template_name: str,
    session: Session,
    phone_number: Optional[str] = None,
) -> User:
    """Send the OTP via SMS to the user for the given message template

    Args:
        user: User object for whom the SMS needs to be sent
        message_template_name: Message template name to include in the SMS
        session: Session object
        phone_number: The phone number to send the SMS to. If passed, it is used in
            place of the one stored in the database, which allows a user to register
            a new phone number or change an existing one.

    Returns:
        The user object (merged into the session) if the SMS is sent successfully

    Raises:
        HTTPException: With status 400 and the text reported by the SMS service if
            its response contains a "requestError". Presumably this also covers the
            per-number daily message limit of the provider - confirm.
        HTTPException: With status 400 if the reported smsStatus is "MESSAGE_NOT_SENT"
    """
    user = session.merge(user)

    application_id, message_id = get_app_and_message_id(
        message_template_name=message_template_name
    )

    # One SMS row per user, application and message template; create it on first use
    sms_query = (
        select(SMS)
        .where(SMS.user == user)
        .where(SMS.application_id == application_id)
        .where(SMS.message_id == message_id)
    )
    sms = session.exec(sms_query).one_or_none()

    if sms is None:
        with commit_or_rollback(session):
            sms = SMS(application_id=application_id, message_id=message_id)
            sms.user = user
            session.add(sms)

    # An explicitly passed number wins over the one stored for the user
    recipient = user.phone_number if phone_number is None else phone_number

    # Called through the module attribute so that the tests can monkeypatch send_sms
    sms_send_status = airt_service.sms_utils.send_sms(
        sms.application_id, sms.message_id, recipient
    )

    if "requestError" in sms_send_status:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{sms_send_status['requestError']['serviceException']['text']}",
        )

    if sms_send_status["smsStatus"] == "MESSAGE_NOT_SENT":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["MESSAGE_NOT_SENT"],
        )

    # Replace the protocol row of a previous send, if any, with the new outcome
    previous_protocol = session.exec(
        select(SMSProtocol).where(SMSProtocol.sms_id == sms.id)
    ).one_or_none()

    if previous_protocol is not None:
        with commit_or_rollback(session):
            session.delete(previous_protocol)

    create_sms_protocol(sms_send_status, sms, session)

    return user

In [None]:
# | export


@user_router.get("/send_sms_otp")
def send_sms_otp(
    username: str,
    message_template_name: str,
    session: Session = Depends(get_session),
) -> str:
    """Send the OTP via SMS to the user"""

    user = get_user(username)

    # The SMS is only sent to an existing user with a verified phone number and
    # for an allowed message template. The checks short-circuit in this order.
    can_receive_otp = (
        user is not None
        and user.phone_number is not None
        and user.is_phone_number_verified
        and message_template_name in _get_allowed_message_template_names()
    )

    if can_receive_otp:
        _send_sms_otp_to_user(
            user=user,
            message_template_name=message_template_name,
            session=session,
        )

    # The same message is returned whether or not an SMS was sent
    return SEND_SMS_OTP_MSG

In [None]:
# Tests for send_sms_otp:
# Negative Scenario: User phone number is not updated in the database
# (an invalid message template name is passed as well)
# NOTE(review): `session` is not created in this cell - it is the variable left over
# from the previous test cell, so this cell depends on hidden notebook state.

test_username = create_user_for_testing()
message_template_name = "invalid_message_template_name"
actual = send_sms_otp(
    username=test_username, message_template_name=message_template_name, session=session
)
display(actual)

# No SMS row must have been created for the user
with pytest.raises(NoResultFound) as e:
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        sms = session.exec(select(SMS).where(SMS.user == user)).one()

assert "No row was found when one was required" in str(e.value)
str(e.value)

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'No row was found when one was required'

In [None]:
# Tests for send_sms_otp:
# Negative Scenario: User phone number is updated but not verified in the database

test_username = create_user_for_testing()
message_template_name = "reset_password"
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    # Store a phone number but leave is_phone_number_verified untouched
    user.phone_number = "91123456789"
    session.add(user)
    session.commit()

    actual = send_sms_otp(
        username=test_username,
        message_template_name=message_template_name,
        session=session,
    )
    display(actual)

# No SMS row must have been created for the user
with pytest.raises(NoResultFound) as e:
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        sms = session.exec(select(SMS).where(SMS.user == user)).one()

assert "No row was found when one was required" in str(e.value)
str(e.value)

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'No row was found when one was required'

In [None]:
# Tests for send_sms_otp:
# Negative Scenario: User phone number is updated and verified but invalid message template is passed

for message_template_name in ["invalid_message_template_name", "register_phone_number"]:
    test_username = create_user_for_testing()
    random_phone_number = "910000000000"
    random_sms_pin_id = "my_random_pin_id"
    with MonkeyPatch.context() as monkeypatch:
        send_sms_sample_response = {
            "pinId": random_sms_pin_id,
            "to": random_phone_number,
            "ncStatus": "NC_NOT_CONFIGURED",
            "smsStatus": "MESSAGE_SENT",
        }
        monkeypatch.setattr(
            "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
        )

        with get_session_with_context() as session:
            user = session.exec(
                select(User).where(User.username == test_username)
            ).one()
            user.phone_number = random_phone_number
            user.is_phone_number_verified = True
            session.add(user)
            session.commit()

        actual = send_sms_otp(
            username=test_username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

    with pytest.raises(NoResultFound) as e:
        with get_session_with_context() as session:
            user = session.exec(
                select(User).where(User.username == test_username)
            ).one()
            sms = session.exec(select(SMS).where(SMS.user == user)).one()

    assert "No row was found when one was required" in str(e.value)
    display(str(e.value))

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'No row was found when one was required'

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'No row was found when one was required'

In [None]:
# Tests for send_sms_otp:
# Positive Scenario: User phone number is updated and verified and a valid message template is passed

test_username = create_user_for_testing()
random_phone_number = "910000000000"
random_sms_pin_id = "my_random_pin_id"
message_template_name = "reset_password"
with MonkeyPatch.context() as monkeypatch:
    # Replace the real SMS call with a canned successful response
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        user.phone_number = random_phone_number
        user.is_phone_number_verified = True
        session.add(user)
        session.commit()

    # NOTE(review): `session` is passed after its with block has exited - confirm
    # that get_session_with_context leaves it usable.
    actual = send_sms_otp(
        username=test_username,
        message_template_name=message_template_name,
        session=session,
    )
    display(actual)

# Exactly one SMS row and one matching protocol row must now exist for the user
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    sms = session.exec(select(SMS).where(SMS.user == user)).one()
    display(sms)
    assert sms.user_id == user.id

    sms_protocol = session.exec(
        select(SMSProtocol).where(SMSProtocol.sms_id == sms.id)
    ).one()

    display(sms_protocol)
    assert sms_protocol.sms_id == sms.id
    assert sms_protocol.pin_id == random_sms_pin_id
    assert sms_protocol.phone_number == random_phone_number

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

SMS(message_id='74444CB4E552B0BF5E57DC305B49DD64', application_id='0636B234BE48C57A95BB07AA93A5483E', id=31, user_id=419)

SMSProtocol(id=31, number_lookup_status='NC_NOT_CONFIGURED', phone_number='910000000000', pin_attempts_remaining=None, pin_id='my_random_pin_id', sent_sms_status='MESSAGE_SENT', pin_verified=False, sms_id=31)

In [None]:
# | export


def require_otp_or_totp_if_mfa_enabled(
    message_template_name: str,
) -> Callable[..., Any]:
    """A decorator factory to validate the totp/otp for MFA enabled user

    The decorated callable must be called with the keyword arguments ``user`` and
    ``session``; the code to validate is looked up under the name "otp" in the
    keyword arguments. If the totp/otp validation fails, the user will not be
    granted access to the decorated route.

    Args:
        message_template_name: Name of the message template that was used to send the SMS

    Returns:
        The decorator to apply to the route

    Raises:
        HTTPException: With status 400, raised by the decorated callable when a code
            is passed although MFA is not active, when MFA is active and no code is
            passed, or when the code is valid neither as TOTP nor as SMS OTP
    """

    def outer_wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
            user = kwargs["user"]
            session = kwargs["session"]
            otp_or_totp = get_attr_by_name(kwargs, "otp")
            code_passed = otp_or_totp is not None

            if not user.is_mfa_active:
                # Without MFA there is nothing to validate; passing a code is an error
                if code_passed:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=ERRORS["MFA_NOT_ACTIVATED_BUT_PASSES_OTP"],
                    )
                return func(*args, **kwargs)

            if not code_passed:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=ERRORS["OTP_REQUIRED"],
                )

            # Try the code as authenticator TOTP first and fall back to the SMS OTP
            try:
                validate_totp(user.mfa_secret, otp_or_totp)
            except HTTPException:
                try:
                    validate_otp(
                        user=user,
                        otp=otp_or_totp,
                        message_template_name=message_template_name,
                        session=session,
                    )
                except HTTPException:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=ERRORS["INVALID_OTP"],
                    )

            return func(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper

In [None]:
# Minimal decorated function: returns "Ok" whenever the decorator lets the call through
@require_otp_or_totp_if_mfa_enabled(message_template_name="disable_mfa")
def test_require_otp_or_totp_if_mfa_enabled(
    otp,
    user,
    session,
):
    return "Ok"


with get_session_with_context() as session:
    test_user = create_user_for_testing()
    user = session.exec(select(User).where(User.username == test_user)).one()
    # Negative scenario: MFA is not active but an OTP is passed
    with pytest.raises(HTTPException) as e:
        random_otp = "123123"
        test_require_otp_or_totp_if_mfa_enabled(
            otp=random_otp, user=user, session=session
        )
    display(e.value.detail)
    assert (
        e.value.detail
        == "MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account."
    )

    # Positive scenario: MFA is not active and no OTP is passed
    actual = test_require_otp_or_totp_if_mfa_enabled(
        otp=None, user=user, session=session
    )
    assert actual == "Ok"
    display(actual)

    # Generate a secret and activate MFA for the user before the remaining scenarios
    actual = generate_mfa_url(user=user, session=session)

    activate_mfa_request = ActivateMFARequest(
        user_otp=pyotp.TOTP(user.mfa_secret).now()
    )

    actual = activate_mfa(
        activate_mfa_request=activate_mfa_request, user=user, session=session
    )
    display(actual)
    assert actual.is_mfa_active

    # Negative scenario: MFA is active and an arbitrary OTP is passed
    with pytest.raises(HTTPException) as e:
        random_otp = "123123"
        test_require_otp_or_totp_if_mfa_enabled(
            otp=random_otp, user=user, session=session
        )
    display(e.value.detail)
    assert "Invalid OTP" in e.value.detail

    # Negative scenario: MFA is active and no OTP is passed
    with pytest.raises(HTTPException) as e:
        random_otp = None
        test_require_otp_or_totp_if_mfa_enabled(
            otp=random_otp, user=user, session=session
        )
    display(e.value.detail)
    assert "OTP is required" in e.value.detail

    # Positive scenario: MFA is active and a valid TOTP is passed
    valid_totp = pyotp.TOTP(user.mfa_secret).now()
    actual = test_require_otp_or_totp_if_mfa_enabled(
        otp=valid_totp, user=user, session=session
    )
    assert actual == "Ok"
    display(actual)

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

'Ok'

User(id=420, uuid=UUID('8ced6951-8fcc-45d5-9544-e0bc57392724'), username='fgeigjnfcq', first_name='unittest', last_name='user', email='fgeigjnfcq@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 20), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

'Invalid OTP. Please try again.'

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Ok'

In [None]:
# | export


@user_router.delete(
    "/mfa/{user_uuid_or_name}/disable",
    response_model=UserRead,
    responses={
        400: {
            "model": HTTPError,
            "description": ERRORS["INCORRECT_USER_UUID"],
        },
        401: {
            "model": HTTPError,
            "description": ERRORS["NOT_ENOUGH_PERMISSION_TO_ACCESS_OTHERS_DATA"],
        },
    },
)
@require_otp_or_totp_if_mfa_enabled(message_template_name="disable_mfa")
def disable_mfa(
    user_uuid_or_name: str,
    otp: Optional[str] = None,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> User:
    """Disable MFA"""

    # otp is not read here; the decorator validates it against the calling user
    acting_user = session.merge(user)
    target_user = get_user_to_disable_mfa(acting_user, session, user_uuid_or_name)

    # Deactivate MFA and drop the stored secret of the target user
    with commit_or_rollback(session):
        target_user.is_mfa_active = False
        target_user.mfa_secret = None
        session.add(target_user)

    return target_user

In [None]:
# Positive scenario: MFA enabled normal user trying to disable mfa for self (using TOTP)
# NOTE: `mfa_active_user` is defined in an earlier cell of this notebook.
with get_session_with_context() as session:
    user = session.exec(
        select(User).where(User.username == mfa_active_user.username)
    ).one()

# `user` was loaded in the session above and is passed into a new session here;
# disable_mfa merges it into the session it receives.
with get_session_with_context() as session:
    expected = session.exec(select(User).where(User.username == user.username)).one()
    actual = disable_mfa(
        user_uuid_or_name=str(user.uuid),
        otp=pyotp.TOTP(user.mfa_secret).now(),
        user=user,
        session=session,
    )

    display(actual)
    assert actual.id == expected.id

    # Negative scenario: normal user trying to disable mfa for others
    # No OTP is passed here. Only the exception type is checked; the error
    # detail is not asserted.
    with pytest.raises(HTTPException) as e:
        disable_mfa(
            user_uuid_or_name=INVALID_UUID_FOR_TESTING, user=user, session=session
        )

User(id=410, uuid=UUID('487479d2-a27f-46ae-939e-3f896dc77b2c'), username='kxzpllgnsa', first_name='unittest', last_name='user', email='kxzpllgnsa@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 9), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
# Positive scenario: MFA enabled normal user trying to disable mfa using sms otp
test_username = create_user_for_testing()
random_phone_number = "910000000000"
random_sms_pin_id = "my_random_pin_id"
message_template_name = "disable_mfa"
random_sms_otp = "111111"
# Replace the SMS helpers in airt_service.sms_utils with lambdas that return
# canned responses for the duration of this context.
with MonkeyPatch.context() as monkeypatch:
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    # The stubbed verification always reports the pin as verified.
    verify_pin_sample_response = {
        "pinId": random_sms_pin_id,
        "msisdn": random_phone_number,
        "verified": True,
        "attemptsRemaining": 0,
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    with get_session_with_context() as session:
        # Give the test user a verified phone number before requesting an SMS OTP.
        user = session.exec(select(User).where(User.username == test_username)).one()
        user.phone_number = random_phone_number
        user.is_phone_number_verified = True
        session.add(user)
        session.commit()
        session.refresh(user)

        # Generate the MFA secret; the returned value is not inspected here.
        actual = generate_mfa_url(user=user, session=session)
        assert user.mfa_secret is not None
        # activate MFA
        activate_mfa_request = ActivateMFARequest(
            user_otp=pyotp.TOTP(user.mfa_secret).now()
        )
        actual = activate_mfa(
            activate_mfa_request=activate_mfa_request, user=user, session=session
        )
        assert actual.is_mfa_active
        display(actual)

        # Request the SMS OTP using the "disable_mfa" message template.
        actual = send_sms_otp(
            username=test_username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

        # Disable MFA with the SMS OTP instead of a TOTP from the secret.
        user = disable_mfa(
            user_uuid_or_name=str(user.uuid),
            otp=random_sms_otp,
            user=user,
            session=session,
        )

        display(user)
        assert not user.is_mfa_active

User(id=421, uuid=UUID('b102337b-45bc-476f-b5ab-7775507ae135'), username='rhbyhtudpu', first_name='unittest', last_name='user', email='rhbyhtudpu@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 24), phone_number='910000000000', is_phone_number_verified=True, mfa_secret=**********************************, is_mfa_active=True)

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

User(id=421, uuid=UUID('b102337b-45bc-476f-b5ab-7775507ae135'), username='rhbyhtudpu', first_name='unittest', last_name='user', email='rhbyhtudpu@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 24), phone_number='910000000000', is_phone_number_verified=True, mfa_secret=****, is_mfa_active=False)

In [None]:
# Positive scenario: Super user trying to disable mfa for other users (using TOTP)
# create a super user and activate MFA
test_super_username = create_user_for_testing()
with get_session_with_context() as session:
    # Promote the freshly created test user to super user.
    super_user = session.exec(
        select(User).where(User.username == test_super_username)
    ).one()
    super_user.super_user = True
    session.add(super_user)
    session.commit()
    session.refresh(super_user)

    mfa_url = generate_mfa_url(user=super_user, session=session)
    activate_mfa_request = ActivateMFARequest(
        user_otp=pyotp.TOTP(super_user.mfa_secret).now()
    )
    actual = activate_mfa(
        activate_mfa_request=activate_mfa_request, user=super_user, session=session
    )
    assert actual.is_mfa_active
    display(super_user)

# create a normal user and activate MFA
with get_session_with_context() as session:
    test_username = create_user_for_testing()
    normal_user = session.exec(select(User).where(User.username == test_username)).one()
    mfa_url = generate_mfa_url(user=normal_user, session=session)
    # Mask the provisioning URL in the cell output; only its length is shown.
    display(f"mfa_url={'*'*len(mfa_url.mfa_url)}")

    activate_mfa_request = ActivateMFARequest(
        user_otp=pyotp.TOTP(normal_user.mfa_secret).now()
    )
    actual = activate_mfa(
        activate_mfa_request=activate_mfa_request, user=normal_user, session=session
    )
    assert actual.is_mfa_active
    display(normal_user)

# login in as super user and disable MFA for the normal user
with get_session_with_context() as session:
    super_user = session.exec(
        select(User).where(User.username == test_super_username)
    ).one()
    display(super_user)

    # The OTP comes from the super user's own secret, not the target user's.
    # `normal_user` is reused here after the session block that loaded it ended.
    actual = disable_mfa(
        user_uuid_or_name=str(normal_user.uuid),
        otp=pyotp.TOTP(super_user.mfa_secret).now(),
        user=super_user,
        session=session,
    )

    display(actual)
    assert actual.id == normal_user.id
    assert not actual.is_mfa_active

# disable the super user
# Cleanup: revoke the super-user flag and disable the account.
with get_session_with_context() as session:
    super_user = session.exec(
        select(User).where(User.username == test_super_username)
    ).one()
    super_user.super_user = False
    super_user.disabled = True
    session.add(super_user)
    session.commit()
    session.refresh(super_user)

display(super_user)

User(id=422, uuid=UUID('eaf35757-2555-4029-b371-4c9f0b3e661b'), username='uueddugbqi', first_name='unittest', last_name='user', email='uueddugbqi@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=True, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 33), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

'mfa_url=****************************************************************************************************'

User(id=423, uuid=UUID('016ff5b6-4e10-4957-8227-27fb3bfd7bbc'), username='wbkaewznuo', first_name='unittest', last_name='user', email='wbkaewznuo@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 33), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

User(id=422, uuid=UUID('eaf35757-2555-4029-b371-4c9f0b3e661b'), username='uueddugbqi', first_name='unittest', last_name='user', email='uueddugbqi@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=True, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 33), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

User(id=423, uuid=UUID('016ff5b6-4e10-4957-8227-27fb3bfd7bbc'), username='wbkaewznuo', first_name='unittest', last_name='user', email='wbkaewznuo@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 33), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

User(id=422, uuid=UUID('eaf35757-2555-4029-b371-4c9f0b3e661b'), username='uueddugbqi', first_name='unittest', last_name='user', email='uueddugbqi@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=True, created=datetime.datetime(2023, 3, 10, 7, 53, 33), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

In [None]:
# Positive scenario: Super user trying to disable mfa for other users (using SMS OTP)
# create a super user and activate MFA
test_super_username = create_user_for_testing()
random_phone_number = "910000000000"
random_sms_pin_id = "my_random_pin_id"
message_template_name = "disable_mfa"
random_sms_otp = "111111"

with get_session_with_context() as session:
    # Promote the test user and give it a verified phone number for SMS OTP.
    super_user = session.exec(
        select(User).where(User.username == test_super_username)
    ).one()
    super_user.super_user = True
    super_user.phone_number = random_phone_number
    super_user.is_phone_number_verified = True
    session.add(super_user)
    session.commit()
    session.refresh(super_user)

    mfa_url = generate_mfa_url(user=super_user, session=session)
    activate_mfa_request = ActivateMFARequest(
        user_otp=pyotp.TOTP(super_user.mfa_secret).now()
    )
    actual = activate_mfa(
        activate_mfa_request=activate_mfa_request, user=super_user, session=session
    )
    assert actual.is_mfa_active
    display(super_user)

# create a normal user and activate MFA
with get_session_with_context() as session:
    test_username = create_user_for_testing()
    normal_user = session.exec(select(User).where(User.username == test_username)).one()
    mfa_url = generate_mfa_url(user=normal_user, session=session)
    # Mask the provisioning URL in the cell output; only its length is shown.
    display(f"mfa_url={'*'*len(mfa_url.mfa_url)}")

    activate_mfa_request = ActivateMFARequest(
        user_otp=pyotp.TOTP(normal_user.mfa_secret).now()
    )
    actual = activate_mfa(
        activate_mfa_request=activate_mfa_request, user=normal_user, session=session
    )
    assert actual.is_mfa_active
    display(normal_user)

# Replace the SMS helpers in airt_service.sms_utils with lambdas that return
# canned responses for the duration of this context.
with MonkeyPatch.context() as monkeypatch:
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    # The stubbed verification always reports the pin as verified.
    verify_pin_sample_response = {
        "pinId": random_sms_pin_id,
        "msisdn": random_phone_number,
        "verified": True,
        "attemptsRemaining": 0,
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    # login in as super user and disable MFA for the normal user
    with get_session_with_context() as session:
        super_user = session.exec(
            select(User).where(User.username == test_super_username)
        ).one()
        display(super_user)

        # The SMS OTP is requested for the super user (the caller), not the target.
        actual = send_sms_otp(
            username=test_super_username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

        # `normal_user` is reused here after the session block that loaded it ended.
        actual = disable_mfa(
            user_uuid_or_name=str(normal_user.uuid),
            otp=random_sms_otp,
            user=super_user,
            session=session,
        )

        display(actual)
        assert actual.id == normal_user.id
        assert not actual.is_mfa_active

# disable the super user
# Cleanup: revoke the super-user flag and disable the account.
with get_session_with_context() as session:
    super_user = session.exec(
        select(User).where(User.username == test_super_username)
    ).one()
    super_user.super_user = False
    super_user.disabled = True
    session.add(super_user)
    session.commit()
    session.refresh(super_user)

display(super_user)

User(id=424, uuid=UUID('472f5b85-b9c3-4102-ba94-2f028a1011c9'), username='wzuafnykaj', first_name='unittest', last_name='user', email='wzuafnykaj@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=True, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 34), phone_number='910000000000', is_phone_number_verified=True, mfa_secret=**********************************, is_mfa_active=True)

'mfa_url=****************************************************************************************************'

User(id=425, uuid=UUID('9e7bfe37-bc47-483e-aae8-5d6b1a635cc1'), username='sqfarrgeoo', first_name='unittest', last_name='user', email='sqfarrgeoo@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 34), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

User(id=424, uuid=UUID('472f5b85-b9c3-4102-ba94-2f028a1011c9'), username='wzuafnykaj', first_name='unittest', last_name='user', email='wzuafnykaj@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=True, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 34), phone_number='910000000000', is_phone_number_verified=True, mfa_secret=**********************************, is_mfa_active=True)

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

User(id=425, uuid=UUID('9e7bfe37-bc47-483e-aae8-5d6b1a635cc1'), username='sqfarrgeoo', first_name='unittest', last_name='user', email='sqfarrgeoo@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 34), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

User(id=424, uuid=UUID('472f5b85-b9c3-4102-ba94-2f028a1011c9'), username='wzuafnykaj', first_name='unittest', last_name='user', email='wzuafnykaj@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=True, created=datetime.datetime(2023, 3, 10, 7, 53, 34), phone_number='910000000000', is_phone_number_verified=True, mfa_secret=**********************************, is_mfa_active=True)

In [None]:
# Context manager to create MFA enabled user


@contextmanager
def create_mfa_enabled_user():
    """Yield a freshly created test user with MFA activated, plus its session.

    Yields:
        A ``(user, session)`` tuple; the session stays open while the caller's
        ``with`` body runs.

    On exit (including on error) MFA is disabled again with a TOTP derived from
    the user's current secret, and the account is marked as disabled.
    """
    with get_session_with_context() as session:
        new_username = create_user_for_testing()
        lookup = select(User).where(User.username == new_username)
        user = session.exec(lookup).one()
        try:
            # Generating the URL is what populates ``user.mfa_secret``.
            generate_mfa_url(user=user, session=session)
            assert user.mfa_secret is not None
            # Activate MFA with a TOTP computed from that secret.
            current_totp = pyotp.TOTP(user.mfa_secret).now()
            activate_mfa(
                activate_mfa_request=ActivateMFARequest(user_otp=current_totp),
                user=user,
                session=session,
            )
            yield user, session
        finally:
            # Teardown: switch MFA off, then disable the account.
            user = disable_mfa(
                user_uuid_or_name=str(user.uuid),
                otp=pyotp.TOTP(user.mfa_secret).now(),
                user=user,
                session=session,
            )
            with commit_or_rollback(session):
                user.disabled = True
                session.add(user)
            assert user.mfa_secret is None
            assert user.disabled


# Smoke test: the context manager yields a (user, session) tuple.
with create_mfa_enabled_user() as user_and_session:
    display(user_and_session[0])
    display(user_and_session[1])

User(id=426, uuid=UUID('dee992db-e49e-4d39-ba19-9804d5b2c052'), username='kvprnvowrn', first_name='unittest', last_name='user', email='kvprnvowrn@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 42), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

<sqlmodel.orm.session.Session>

In [None]:
# MFA enabled user trying to generate mfa url
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    # Negative scenario: Not passing the OTP in the request
    with pytest.raises(HTTPException) as e:
        generate_mfa_url(user=user, session=session)
    assert "OTP is required" in str(e.value.detail)
    display(e.value.detail)
    # Negative scenario: passing invalid OTP in the request
    with pytest.raises(HTTPException) as e:
        # NOTE: passed as an int here, while other cells pass the OTP as a string.
        random_otp = 111111
        generate_mfa_url(otp=random_otp, user=user, session=session)
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Positive scenario: passing valid OTP in the request
    actual = generate_mfa_url(
        otp=pyotp.TOTP(user.mfa_secret).now(), user=user, session=session
    )
    # Mask the provisioning URL in the cell output; only its length is shown.
    display(f'mfa_url: {"*"*len(actual.mfa_url)}')

# Negative scenario: Non MFA user passing OTP in the request
# The context manager's teardown has disabled MFA for `user` by now; `user` and
# `session` are reused after the `with` block above has exited.
with pytest.raises(HTTPException) as e:
    random_otp = 111111
    generate_mfa_url(otp=random_otp, user=user, session=session)
assert (
    str(e.value.detail)
    == "MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account."
)
display(e.value.detail)

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

'mfa_url: ****************************************************************************************************'

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

In [None]:
# MFA enabled user trying to disable mfa
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    # Negative scenario: Not passing the OTP in the request
    # NOTE: `user.uuid` is passed as-is here; other cells pass `str(user.uuid)`.
    with pytest.raises(HTTPException) as e:
        disable_mfa(user_uuid_or_name=user.uuid, user=user, session=session)
    assert "OTP is required" in str(e.value.detail)
    display(e.value.detail)
    # Negative scenario: passing invalid OTP in the request
    with pytest.raises(HTTPException) as e:
        # NOTE: passed as an int here, while other cells pass the OTP as a string.
        random_otp = 111111
        disable_mfa(
            user_uuid_or_name=user.uuid, otp=random_otp, user=user, session=session
        )
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

# Negative scenario: Non MFA user passing OTP in the request
# The context manager's teardown has disabled MFA for `user` by now; `user` and
# `session` are reused after the `with` block above has exited.
with pytest.raises(HTTPException) as e:
    random_otp = 111111
    disable_mfa(user_uuid_or_name=user.uuid, otp=random_otp, user=user, session=session)
assert (
    str(e.value.detail)
    == "MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account."
)
display(e.value.detail)

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

In [None]:
# | exporti


@patch(cls_method=True)  # type: ignore
def _create(cls: User, user_to_create: UserCreate, session: Session) -> User:
    """Method to create new user

    The password is hashed before the row is built. The hash is written into a
    copy of the request data, so the caller's ``user_to_create`` object is not
    modified and can safely be reused.

    Args:
        user_to_create: UserCreate object
        session: DB session object

    Returns:
        A newly created user object

    Raises:
        HTTPException: if username or email already exists in database
    """
    # Work on a dict copy so the plain-text password on the request object is
    # never overwritten (re-submitting the same object must not double-hash).
    user_data = user_to_create.dict()
    user_data["password"] = get_password_hash(user_to_create.password)
    new_user = User(**user_data)

    try:
        session.add(new_user)
        session.commit()
        session.refresh(new_user)
    except IntegrityError as exc:
        # Put the session back into a usable state before reporting the error.
        session.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["USERNAME_OR_EMAIL_ALREADY_EXISTS"],
        ) from exc
    return new_user

In [None]:
# | export


@user_router.post(
    "/",
    response_model=UserRead,
    responses={
        400: {
            "model": HTTPError,
            "description": ERRORS["USERNAME_OR_EMAIL_ALREADY_EXISTS"],
        },
        401: {
            "model": HTTPError,
            "description": ERRORS["NOT_ENOUGH_PERMISSION"],
        },
    },
)
@require_otp_if_mfa_enabled
@ensure_super_user
def create_user(
    user_to_create: UserCreate,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> User:
    """
    Create new user
    """
    # The docstring above is kept verbatim: FastAPI publishes it as the route
    # description.

    # Attach the authenticated caller to this request's session.
    user = session.merge(user)

    # Hashing, persistence and duplicate handling live in User._create.
    created_user = User._create(user_to_create, session)  # type: ignore
    return created_user

In [None]:
# Context manager to create MFA enabled super user
@contextmanager
def create_mfa_enabled_super_user():
    """Yield a freshly created super user with MFA activated, plus its session.

    The account is created through ``create_user`` on behalf of the existing
    "kumaran" user, then looked up again in a second session where MFA is
    generated and activated.

    Yields:
        A ``(user, session)`` tuple; the session stays open while the caller's
        ``with`` body runs.

    On exit (including on error) MFA is disabled, the super-user flag is
    revoked and the account is marked as disabled.
    """
    with get_session_with_context() as session:
        creator = session.exec(select(User).where(User.username == "kumaran")).one()

        # Random 10-letter lowercase username; also reused as the password.
        letters = [random.choice(string.ascii_lowercase) for _ in range(10)]
        username = "".join(letters)

        super_user_to_create = UserCreate(
            username=username,
            first_name=f"first_name_{username}",
            last_name=f"last_name_{username}",
            email=f"{username}@email.com",
            password=username,
            subscription_type="test",
            super_user=True,
            otp=None,
        )
        created = create_user(
            user_to_create=super_user_to_create, user=creator, session=session
        )
        assert created.username == super_user_to_create.username
        assert created.username is not None
        assert created.super_user

    with get_session_with_context() as session:
        try:
            lookup = select(User).where(User.username == super_user_to_create.username)
            user = session.exec(lookup).one()
            # Generating the URL is what populates ``user.mfa_secret``.
            generate_mfa_url(user=user, session=session)
            assert user.mfa_secret is not None
            # Activate MFA with a TOTP computed from that secret.
            current_totp = pyotp.TOTP(user.mfa_secret).now()
            activate_mfa(
                activate_mfa_request=ActivateMFARequest(user_otp=current_totp),
                user=user,
                session=session,
            )
            yield user, session
        finally:
            # Teardown: switch MFA off, then demote and disable the account.
            user = disable_mfa(
                user_uuid_or_name=str(user.uuid),
                otp=pyotp.TOTP(user.mfa_secret).now(),
                user=user,
                session=session,
            )
            with commit_or_rollback(session):
                user.super_user = False
                user.disabled = True
                session.add(user)


# Inside the context: the yielded account is an active super user with MFA on.
with create_mfa_enabled_super_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    display(user)
    assert user.super_user
    assert not user.disabled
    assert user.is_mfa_active
    assert user.mfa_secret is not None
    # Keep the id so the teardown can be verified from a fresh session below.
    test_user_id = user.id

# After the context: re-read the row and check the teardown took effect.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.id == test_user_id)).one()
    display(user)
    assert not user.super_user
    assert user.disabled
    assert not user.is_mfa_active
    assert user.mfa_secret is None

User(id=429, uuid=UUID('8ac26796-c842-4b51-a421-1e823cb0c4fd'), username='iewhikqipo', first_name='first_name_iewhikqipo', last_name='last_name_iewhikqipo', email='iewhikqipo@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=True, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 47), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

User(id=429, uuid=UUID('8ac26796-c842-4b51-a421-1e823cb0c4fd'), username='iewhikqipo', first_name='first_name_iewhikqipo', last_name='last_name_iewhikqipo', email='iewhikqipo@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=True, created=datetime.datetime(2023, 3, 10, 7, 53, 47), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
# MFA enabled super user trying to create new user
with create_mfa_enabled_super_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    # Negative scenario: Not passing the OTP in the request
    with pytest.raises(HTTPException) as e:
        username = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
        # No `otp` field is set on this request.
        user_to_create = UserCreate(
            username=username,
            first_name=f"first_name_{username}",
            last_name=f"last_name_{username}",
            email=f"{username}@email.com",
            password=username,
            subscription_type="test",
            super_user=True,
        )
        create_user(user_to_create=user_to_create, user=user, session=session)
    assert "OTP is required" in str(e.value.detail)
    display(e.value.detail)

    # Negative scenario: passing invalid OTP in the request
    with pytest.raises(HTTPException) as e:
        # NOTE: passed as an int here, while other cells pass the OTP as a string.
        random_otp = 111111
        username = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
        user_to_create = UserCreate(
            username=username,
            first_name=f"first_name_{username}",
            last_name=f"last_name_{username}",
            email=f"{username}@email.com",
            password=username,
            subscription_type="test",
            super_user=True,
            otp=random_otp,
        )
        create_user(user_to_create=user_to_create, user=user, session=session)
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Positive scenario: passing valid OTP in the request
    username = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
    # The OTP travels inside the request body and is derived from the calling
    # super user's own MFA secret.
    user_to_create = UserCreate(
        username=username,
        first_name=f"first_name_{username}",
        last_name=f"last_name_{username}",
        email=f"{username}@email.com",
        password=username,
        subscription_type="test",
        super_user=True,
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
    )
    actual = create_user(user_to_create=user_to_create, user=user, session=session)
    display(actual)
    assert actual.username == username

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

User(id=431, uuid=UUID('f6d93ba5-7f6a-400b-8528-56f5e71498b7'), username='lembepxoyf', first_name='first_name_lembepxoyf', last_name='last_name_lembepxoyf', email='lembepxoyf@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=True, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 48), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
# Negative Scenario: Non-MFA super user trying to create a new user by passing OTP
with get_session_with_context() as session:
    # NOTE: the lookup and the UserCreate construction also sit inside the
    # pytest.raises block, not just the create_user call.
    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        # "kumaran" is used as the non-MFA super user here; it is presumably
        # seeded by the test environment (not created in this section).
        user_kumaran = session.exec(
            select(User).where(User.username == "kumaran")
        ).one()

        username = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
        user_to_create = UserCreate(
            username=username,
            first_name=f"first_name_{username}",
            last_name=f"last_name_{username}",
            email=f"{username}@email.com",
            password=username,
            subscription_type="test",
            otp=random_otp,
        )
        create_user(user_to_create=user_to_create, user=user_kumaran, session=session)

# The full error message is compared, not just a substring.
assert (
    str(e.value.detail)
    == "MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account."
)
display(e.value.detail)

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

In [None]:
# Create a user as the "kumaran" super user, then check the duplicate and
# no-permission failure paths.
with get_session_with_context() as session:
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()

    username = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
    user_to_create = UserCreate(
        username=username,
        first_name="John",
        last_name="Wick",
        email=f"{username}@email.com",
        password=username,
        subscription_type="test",
    )
    display(user_to_create)
    actual = create_user(
        user_to_create=user_to_create, user=user_kumaran, session=session
    )
    display(actual)
    assert actual.username == user_to_create.username
    assert actual.username is not None

    # Create an API key for the new user, expiring one day from now; the result
    # is not inspected in this cell.
    apikey_to_create = APIKeyCreate(expiry=datetime.utcnow() + timedelta(days=1))
    apikey_created = create_apikey(
        apikey_to_create=apikey_to_create, user=actual, session=session
    )
    #     display(apikey_created)

    # Try to create it again
    with pytest.raises(HTTPException) as e:
        create_user(user_to_create=user_to_create, user=user_kumaran, session=session)
    display(e)

# Try to create it with user without create permission
# NOTE: `test_username` is not defined in this cell; it is left over from an
# earlier cell, so this block depends on notebook execution order.
with get_session_with_context() as session:
    not_super_user = session.exec(
        select(User).where(User.username == test_username)
    ).one()
    with pytest.raises(HTTPException) as e:
        create_user(user_to_create=user_to_create, user=not_super_user, session=session)
    display(e)

UserCreate(username='ddlcjumqyk', first_name='John', last_name='Wick', email='ddlcjumqyk@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, phone_number=None, password = '****************************************', otp=None)

User(id=432, uuid=UUID('e0d95436-d297-4f7f-beac-15eaa16d1d9c'), username='ddlcjumqyk', first_name='John', last_name='Wick', email='ddlcjumqyk@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 48), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

<ExceptionInfo HTTPException(status_code=400, detail='The requested username or email already exists. Try another.') tblen=5>

<ExceptionInfo HTTPException(status_code=401, detail='You do not have sufficient permission to access this route. Please contact your administrator for help.') tblen=3>

In [None]:
# | exporti


@patch(cls_method=True)  # type: ignore
def get(cls: User, uuid: str, session: Session) -> User:
    """Look up a single user by uuid

    Args:
        uuid: User uuid
        session: Sqlmodel session

    Returns:
        The user object whose uuid matches

    Raises:
        HTTPException: if no user with the given uuid exists in database
    """
    lookup = select(User).where(User.uuid == uuid)
    try:
        return session.exec(lookup).one()
    except NoResultFound:
        # Translate the missing row into a 400 response for the API caller.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INCORRECT_USER_UUID"],
        )

In [None]:
# | export


class UserUpdateRequest(BaseModel):
    """Request object to update user

    Args:
        username: Updated username
        first_name: Updated first name
        last_name: Updated last name
        email: Updated email
        otp: Dynamically generated six-digit verification code from the authenticator app
    """

    # Every field is optional and defaults to None, so a request may carry any
    # subset of them.
    username: Optional[str] = None
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    # Declared as a plain string: this model performs no email format validation itself.
    email: Optional[str] = None
    otp: Optional[str] = None

In [None]:
# | exporti


@patch(cls_method=True)  # type: ignore
def check_username_exists(cls: User, username: str, session: Session) -> None:
    """Ensure that no stored user already has the given username

    Args:
        username: Username to look up
        session: Sqlmodel session

    Raises:
        HTTPException: if a user with the given username already exists
    """
    lookup = select(User).where(User.username == username)
    try:
        session.exec(lookup).one()
    except NoResultFound:
        # No matching row: the username is free to use
        pass
    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["USERNAME_ALREADY_EXISTS"],
        )

In [None]:
# | exporti


def check_valid_email(email: str) -> str:
    """Check whether the given email has a valid format

    The accepted shape is ``<local>@<domain>.<suffix>`` where none of the three
    parts is empty and none contains ``@`` or whitespace.

    Args:
        email: Email to check

    Returns:
        The email, unchanged, if it is valid

    Raises:
        HTTPException: if the email does not have a valid format
    """
    # fullmatch anchors the pattern at both ends. A plain match() only anchors the
    # start, which let values such as "a@b.com@evil.org" or "a@b.com\n" through
    # because their leading part looked like an email.
    if re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", email) is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INVALID_EMAIL"],
        )
    return email

In [None]:
email = "valid_email@mail.com"
expected = email
actual = check_valid_email(email)
assert actual == expected, expected
actual

'valid_email@mail.com'

In [None]:
email = "invalid_email"
with pytest.raises(HTTPException) as e:
    check_valid_email(email)

assert "Invalid email" in str(e.value.detail), str(e.value.detail)
str(e.value.detail)

'Invalid email format. Please try again.'

In [None]:
# | exporti


@patch(cls_method=True)  # type: ignore
def check_email_exists(cls: User, email: str, session: Session) -> None:
    """Ensure the given email is well formed and not used by any stored user

    Args:
        email: Email to check
        session: Sqlmodel session

    Raises:
        HTTPException: if the email has an invalid format or already exists
    """
    # Format validation comes first so that no query is issued for a malformed value
    validated_email = check_valid_email(email)

    lookup = select(User).where(User.email == validated_email)
    try:
        session.exec(lookup).one()
    except NoResultFound:
        # No matching row: the email is free to use
        pass
    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["EMAIL_ALREADY_EXISTS"],
        )

In [None]:
# | exporti


@patch  # type: ignore
def _update(self: User, to_update: UserUpdateRequest, session: Session) -> User:
    """Apply the non-empty fields of an update request to this user and persist them

    All validation runs before any attribute is modified, so a rejected request
    leaves this (session-attached) object untouched. A username or email that is
    identical to the one already stored on this user is treated as "no change"
    and is not reported as already taken.

    Args:
        to_update: Requested changes; empty or None fields are ignored
        session: Sqlmodel session

    Returns:
        The updated user object

    Raises:
        HTTPException: if the new username or email is already taken, or if the
            new email has an invalid format
    """
    # Values equal to the current ones need neither a uniqueness check nor a write
    new_username = to_update.username
    if new_username == self.username:
        new_username = None
    new_email = to_update.email
    if new_email == self.email:
        new_email = None

    # Validate everything up front (username first, then email) so that a failure
    # in a later check cannot leave an earlier change pending on the object.
    if new_username:
        User.check_username_exists(new_username, session)
    if new_email:
        User.check_email_exists(new_email, session)

    if new_username:
        self.username = new_username
    if new_email:
        self.email = new_email  # type: ignore
    if to_update.first_name:
        self.first_name = to_update.first_name
    if to_update.last_name:
        self.last_name = to_update.last_name

    with commit_or_rollback(session):
        session.add(self)

    return self

In [None]:
# | export


@user_router.post(
    "/{user_uuid_or_name}/update",
    response_model=UserRead,
    responses={
        400: {
            "model": HTTPError,
            "description": ERRORS["USERNAME_OR_EMAIL_ALREADY_EXISTS"],
        },
        401: {
            "model": HTTPError,
            "description": ERRORS["NOT_ENOUGH_PERMISSION"],
        },
    },
)
@require_otp_if_mfa_enabled
def update_user(
    to_update: UserUpdateRequest,
    user_uuid_or_name: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> User:
    """Update user"""
    # Bring the authenticated user into this request's session before using it
    current_user = session.merge(user)

    # get_valid_user presumably resolves the target by uuid or username and
    # rejects access to other users' data - confirm against airt_service.auth
    target_user = get_valid_user(current_user, session, user_uuid_or_name)
    return target_user._update(to_update, session)  # type: ignore

In [None]:
# MFA enabled normal user trying to update their details
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    # Negative scenario: Not passing the OTP in the request
    with pytest.raises(HTTPException) as e:
        user_update_request = UserUpdateRequest(
            username=user.username,
            first_name="first_name_update",
            last_name="last_name_update",
            # Derived from the user created in this cell instead of a `username`
            # variable left over from an earlier cell
            email=f"{user.username}@email.com",
        )

        update_user(
            to_update=user_update_request,
            user_uuid_or_name=str(user.uuid),
            user=user,
            session=session,
        )
    assert "OTP is required" in str(e.value.detail)
    display(e.value.detail)

    # Negative scenario: passing invalid OTP in the request
    with pytest.raises(HTTPException) as e:
        # The otp field is declared as a string, so pass a string
        random_otp = "111111"
        user_update_request = UserUpdateRequest(
            username=user.username,
            first_name="first_name_update",
            last_name="last_name_update",
            email=f"{user.username}@email.com",
            otp=random_otp,
        )

        update_user(
            to_update=user_update_request,
            user_uuid_or_name=str(user.uuid),
            user=user,
            session=session,
        )
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Positive scenario: passing valid OTP in the request
    user_update_request = UserUpdateRequest(
        first_name="updated_first_name",
        last_name="updated_last_name",
        # UserUpdateRequest declares no password field; this extra value is passed
        # on purpose to check below that the route does not change the password
        password="a_new_password",
        otp=pyotp.TOTP(user.mfa_secret).now(),
    )
    actual = update_user(
        to_update=user_update_request,
        user_uuid_or_name=user.username,
        user=user,
        session=session,
    )
    display(actual)
    assert actual.first_name == "updated_first_name"
    assert actual.last_name == "updated_last_name"

    display(f"{verify_password('a_new_password', actual.password)=}")
    assert not verify_password("a_new_password", actual.password)

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

User(id=434, uuid=UUID('ac8689ff-5918-403d-a4e2-79567e8f582d'), username='ceggwtvqua', first_name='updated_first_name', last_name='updated_last_name', email='ceggwtvqua@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 49), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

"verify_password('a_new_password', actual.password)=False"

In [None]:
# Negative Scenario: Non-MFA user trying to update their details by passing OTP
with get_session_with_context() as session:
    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        test_username_for_update = create_user_for_testing()
        user_to_update = session.exec(
            select(User).where(User.username == test_username_for_update)
        ).one()

        username = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
        user_update_request = UserUpdateRequest(
            username=username,
            first_name="first_name_update",
            last_name="last_name_update",
            email=f"{username}@email.com",
            password="a_new_password",
            otp=random_otp,
        )

        update_user(
            to_update=user_update_request,
            user_uuid_or_name=str(user_to_update.uuid),
            user=user_to_update,
            session=session,
        )
assert (
    str(e.value.detail)
    == "MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account."
)
display(e.value.detail)

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

In [None]:
with get_session_with_context() as session:
    test_username_for_update = create_user_for_testing()
    user_to_update = session.exec(
        select(User).where(User.username == test_username_for_update)
    ).one()

    username = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
    user_update_request = UserUpdateRequest(
        username=username,
        first_name="first_name_update",
        last_name="last_name_update",
        email=f"{username}@email.com",
    )

    actual = update_user(
        to_update=user_update_request,
        user_uuid_or_name=str(user_to_update.uuid),
        user=user_to_update,
        session=session,
    )

    display(f"{actual.username=}")

    assert actual.id == user_to_update.id
    assert actual.username == user_update_request.username
    assert actual.first_name == user_update_request.first_name
    assert actual.last_name == user_update_request.last_name
    assert actual.email == user_update_request.email

"actual.username='haqaufhacq'"

In [None]:
with get_session_with_context() as session:
    test_username_for_update = create_user_for_testing()
    user_to_update = session.exec(
        select(User).where(User.username == test_username_for_update)
    ).one()

    user_update_request = UserUpdateRequest(
        username="".join(random.choice(string.ascii_lowercase) for _ in range(10)),
        first_name="first_name_update",
        last_name="last_name_update",
        email="not_a_valid_email",
    )
    with pytest.raises(HTTPException) as e:
        actual = update_user(
            to_update=user_update_request,
            user_uuid_or_name=str(user_to_update.uuid),
            user=user_to_update,
            session=session,
        )
    display(e)

    user_update_request = UserUpdateRequest(
        username="kumaran",
        first_name="first_name_update",
        last_name="last_name_update",
        email="not_a_valid_email",
    )
    with pytest.raises(HTTPException) as e:
        actual = update_user(
            to_update=user_update_request,
            user_uuid_or_name=str(user_to_update.uuid),
            user=user_to_update,
            session=session,
        )
    display(e)

    user_update_request = UserUpdateRequest(
        username="".join(random.choice(string.ascii_lowercase) for _ in range(10)),
        first_name="first_name_update",
        last_name="last_name_update",
        email="kumaran@airt.ai",
    )
    with pytest.raises(HTTPException) as e:
        actual = update_user(
            to_update=user_update_request,
            user_uuid_or_name=str(user_to_update.uuid),
            user=user_to_update,
            session=session,
        )
    display(e)

<ExceptionInfo HTTPException(status_code=400, detail='Invalid email format. Please try again.') tblen=6>

<ExceptionInfo HTTPException(status_code=400, detail='The requested username is already taken. Try another.') tblen=5>

<ExceptionInfo HTTPException(status_code=400, detail='The requested Email is already taken. Try another.') tblen=5>

In [None]:
# | exporti


@patch  # type: ignore
def disable(self: User, session: Session) -> User:
    """Mark this user and all of their API keys as disabled

    Args:
        session: Sqlmodel session

    Returns:
        The disabled user object

    Raises:
        HTTPException: if the user is already disabled
    """
    already_disabled = self.disabled
    if already_disabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["USER_ALREADY_DISABLED"],
        )

    with commit_or_rollback(session):
        self.disabled = True
        session.add(self)

        # The user's API keys are disabled in the same unit of work as the user
        for key in self.apikeys:
            key.disabled = True
            session.add(key)

    return self

In [None]:
# | export


@user_router.delete(
    "/{user_uuid_or_name}",
    response_model=UserRead,
    responses={
        400: {"model": HTTPError, "description": ERRORS["INCORRECT_USER_UUID"]},
        401: {
            "model": HTTPError,
            "description": ERRORS["NOT_ENOUGH_PERMISSION"],
        },
    },
)
@require_otp_if_mfa_enabled
@ensure_super_user
def disable_user(
    user_uuid_or_name: str,
    otp: Optional[str] = None,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> User:
    """Disable user"""
    # `otp` is not read in this body; it is part of the signature for the
    # require_otp_if_mfa_enabled decorator (presumably - confirm in airt_service.totp)
    current_user = session.merge(user)

    target_user = get_valid_user(current_user, session, user_uuid_or_name)
    return target_user.disable(session)

In [None]:
# MFA enabled super user trying to disable a user
with create_mfa_enabled_super_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    # Negative scenario: Not passing the OTP in the request
    with pytest.raises(HTTPException) as e:
        user_to_disable = session.exec(
            select(User).where(User.username == user_to_create.username)
        ).one()

        disable_user(
            user_uuid_or_name=str(user_to_disable.uuid), user=user, session=session
        )
    assert "OTP is required" in str(e.value.detail)
    display(e.value.detail)

    # Negative scenario: passing invalid OTP in the request
    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        user_to_disable = session.exec(
            select(User).where(User.username == user_to_create.username)
        ).one()

        disable_user(
            user_uuid_or_name=str(user_to_disable.uuid),
            otp=random_otp,
            user=user,
            session=session,
        )
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Positive scenario: passing valid OTP in the request
    username = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
    user_to_create_request = UserCreate(
        username=username,
        first_name=f"first_name_{username}",
        last_name=f"last_name_{username}",
        email=f"{username}@email.com",
        password=username,
        subscription_type="test",
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
    )
    sample_user = create_user(
        user_to_create=user_to_create_request, user=user, session=session
    )
    user_to_disable = session.exec(
        select(User).where(User.username == sample_user.username)
    ).one()

    actual = disable_user(
        user_uuid_or_name=str(user_to_disable.username),
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
        user=user,
        session=session,
    )
    display(actual)
    assert actual.username == sample_user.username

    # For following test cases
    sample_user_id = user.id

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

User(id=439, uuid=UUID('263d1ccc-92a0-49fb-b13e-92c2bf378515'), username='fiaavrfqpt', first_name='first_name_fiaavrfqpt', last_name='last_name_fiaavrfqpt', email='fiaavrfqpt@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=True, created=datetime.datetime(2023, 3, 10, 7, 53, 50), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
# Negative Scenario: Non-MFA enabled super user trying to disable new user by passing OTP
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.id == sample_user_id)).one()
    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        user_to_disable = session.exec(
            select(User).where(User.username == user_to_create.username)
        ).one()

        disable_user(
            user_uuid_or_name=str(user_to_disable.uuid),
            otp=random_otp,
            user=user,
            session=session,
        )
    assert (
        str(e.value.detail)
        == "MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account."
    )
    display(e.value.detail)

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

In [None]:
with get_session_with_context() as session:
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()

    user_to_disable = session.exec(
        select(User).where(User.username == user_to_create.username)
    ).one()

    # Try to disable it with a user who is not a super user
    not_super_user = session.exec(
        select(User).where(User.username == test_username)
    ).one()
    with pytest.raises(HTTPException) as e:
        disable_user(
            user_uuid_or_name=str(user_to_disable.uuid),
            user=not_super_user,
            session=session,
        )

    actual = disable_user(
        user_uuid_or_name=str(user_to_disable.uuid), user=user_kumaran, session=session
    )
    display(actual)
    assert actual.disabled == True

    for apikey in actual.apikeys:
        display(apikey)
        assert apikey.disabled == True

User(id=432, uuid=UUID('e0d95436-d297-4f7f-beac-15eaa16d1d9c'), username='ddlcjumqyk', first_name='John', last_name='Wick', email='ddlcjumqyk@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=True, created=datetime.datetime(2023, 3, 10, 7, 53, 48), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

APIKey(name=None, uuid=UUID('7426a0e8-96d2-4a4e-a650-f7f34ec7ba44'), created=datetime.datetime(2023, 3, 10, 7, 53, 48), id=51, expiry=datetime.datetime(2023, 3, 11, 7, 53, 48), disabled=True, user_id=432)

In [None]:
# | exporti


@patch  # type: ignore
def enable(self: User, session: Session) -> User:
    """Re-enable a previously disabled user

    Only the user's own disabled flag is cleared here; no API key is touched.

    Args:
        session: Sqlmodel session

    Returns:
        The enabled user object

    Raises:
        HTTPException: if the user is not currently disabled
    """
    currently_enabled = not self.disabled
    if currently_enabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["USER_ALREADY_ENABLED"],
        )

    with commit_or_rollback(session):
        self.disabled = False
        session.add(self)

    return self

In [None]:
# | export


@user_router.get(
    "/{user_uuid_or_name}/enable",
    response_model=UserRead,
    responses={
        400: {"model": HTTPError, "description": ERRORS["INCORRECT_USER_UUID"]},
        401: {
            "model": HTTPError,
            "description": ERRORS["NOT_ENOUGH_PERMISSION"],
        },
    },
)
@require_otp_if_mfa_enabled
@ensure_super_user
def enable_user(
    user_uuid_or_name: str,
    otp: Optional[str] = None,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> User:
    """Enable user"""
    # NOTE(review): this route changes state although it is registered as GET;
    # the method is kept as is because existing clients depend on it.
    current_user = session.merge(user)

    target_user = get_valid_user(current_user, session, user_uuid_or_name)
    return target_user.enable(session)

In [None]:
# MFA enabled super user trying to enable a user
with create_mfa_enabled_super_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    disabled_user = session.exec(select(User).where(User.id == sample_user_id)).one()
    # Negative scenario: Not passing the OTP in the request
    with pytest.raises(HTTPException) as e:
        enable_user(
            user_uuid_or_name=str(disabled_user.uuid), user=user, session=session
        )
    assert "OTP is required" in str(e.value.detail)
    display(e.value.detail)

    # Negative scenario: passing invalid OTP in the request
    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        enable_user(
            user_uuid_or_name=str(disabled_user.uuid),
            otp=random_otp,
            user=user,
            session=session,
        )
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Positive scenario: passing valid OTP in the request
    username = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
    user_to_create_request = UserCreate(
        username=username,
        first_name=f"first_name_{username}",
        last_name=f"last_name_{username}",
        email=f"{username}@email.com",
        password=username,
        subscription_type="test",
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
    )
    sample_user = create_user(
        user_to_create=user_to_create_request, user=user, session=session
    )
    user_to_enable = session.exec(
        select(User).where(User.username == sample_user.username)
    ).one()
    user_to_enable.disabled = True
    session.add(user_to_enable)
    session.commit()
    session.refresh(user_to_enable)

    actual = enable_user(
        user_uuid_or_name=str(user_to_enable.username),
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
        user=user,
        session=session,
    )
    display(actual)
    assert actual.username == user_to_enable.username

    # For following test case
    test_user_id = user.id
    user_to_enable_uuid = str(user_to_enable.uuid)

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

User(id=441, uuid=UUID('95b5533d-f0b3-4cc8-b3bf-7540bf3498a4'), username='eljnfxpxfa', first_name='first_name_eljnfxpxfa', last_name='last_name_eljnfxpxfa', email='eljnfxpxfa@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 51), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
# Negative Scenario: Non-MFA enabled super user trying to enable a user by passing OTP
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.id == test_user_id)).one()
    with pytest.raises(HTTPException) as e:
        # Only the call under test lives inside pytest.raises; the earlier unused
        # lookup of `user_to_create` (state from another cell) has been dropped.
        enable_user(
            user_uuid_or_name=user_to_enable_uuid,
            otp="111111",  # otp is declared as Optional[str]
            user=user,
            session=session,
        )
assert (
    str(e.value.detail)
    == "MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account."
)
display(e.value.detail)

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

In [None]:
with get_session_with_context() as session:
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()

    username = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
    user_to_create = UserCreate(
        username=username,
        first_name="John",
        last_name="Wick",
        email=f"{username}@email.com",
        password=username,
        subscription_type="test",
    )
    user_to_enable = create_user(
        user_to_create=user_to_create, user=user_kumaran, session=session
    )
    # Disable the fresh user directly so that there is something to enable
    user_to_enable.disabled = True
    session.add(user_to_enable)
    session.commit()
    session.refresh(user_to_enable)

    # Try to enable it with a user who is not a super user
    not_super_user = session.exec(
        select(User).where(User.username == test_username)
    ).one()
    with pytest.raises(HTTPException) as e:
        # The parameter is `user_uuid_or_name`; the previous `user_uuid` keyword
        # only went unnoticed because the permission check fails first.
        enable_user(
            user_uuid_or_name=str(user_to_enable.uuid),
            user=not_super_user,
            session=session,
        )
    display(e)

    actual = enable_user(
        user_uuid_or_name=str(user_to_enable.uuid), user=user_kumaran, session=session
    )
    display(actual)
    assert not actual.disabled

    # Enabling a user does not re-enable any of their API keys
    for apikey in actual.apikeys:
        display(apikey)
        assert apikey.disabled

<ExceptionInfo HTTPException(status_code=401, detail='You do not have sufficient permission to access this route. Please contact your administrator for help.') tblen=3>

User(id=442, uuid=UUID('67ec79bf-6c0e-426d-84ee-07ea6b736130'), username='dikpthchcr', first_name='John', last_name='Wick', email='dikpthchcr@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 51), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
# | exporti


@patch(cls_method=True)  # type: ignore
def get_all(
    cls: User,
    disabled: bool,
    offset: int,
    limit: int,
    session: Session,
) -> List[User]:
    """Function to get one page of users filtered by their disabled flag

    Args:
        disabled: If True return only disabled users, if False return only enabled users
        offset: Offset results by given integer
        limit: Limit results by given integer
        session: Sqlmodel session

    Returns:
        a list of user objects ordered by user id
    """
    statement = (
        select(User)
        .where(User.disabled == disabled)
        # Without an explicit ordering the database may return rows in any order,
        # so consecutive offset/limit pages could overlap or skip users.
        .order_by(User.id)
        .offset(offset)
        .limit(limit)
    )

    return session.exec(statement).all()

In [None]:
# | export


@user_router.get(
    "/",
    response_model=List[UserRead],
    responses={
        401: {
            "model": HTTPError,
            "description": ERRORS["NOT_ENOUGH_PERMISSION"],
        },
    },
)
@ensure_super_user
def get_all_users(
    disabled: bool = False,
    offset: int = 0,
    # `le` is FastAPI's "less than or equal" bound. The previous `lte` keyword is
    # not a validation argument, so the maximum of 100 was never enforced.
    limit: int = Query(default=100, le=100),
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> List[User]:
    """Get all users"""
    # The merged instance is not used below; the call only attaches the
    # authenticated user to this session.
    user = session.merge(user)

    return User.get_all(disabled=disabled, offset=offset, limit=limit, session=session)

In [None]:
# MFA enabled super user trying to get all users
with create_mfa_enabled_super_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    actual = get_all_users(
        disabled=False,
        offset=0,
        limit=1,
        user=user,
        session=session,
    )
    display(actual)
    assert len(actual) > 0

[User(id=1, uuid=UUID('b660dce5-f500-4a5d-86ef-8f81372d368a'), username='johndoe', first_name='John', last_name='Doe', email='johndoe@airt.ai', subscription_type=<SubscriptionType.small: 'small'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 9, 8, 17, 31), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)]

In [None]:
with get_session_with_context() as session:
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()

    actual = get_all_users(
        disabled=False, offset=0, limit=1, user=user_kumaran, session=session
    )
    display(actual)
    assert actual[0].username == "johndoe"
    assert not actual[0].disabled

    actual = get_all_users(
        disabled=True, offset=0, limit=1, user=user_kumaran, session=session
    )
    display(actual)
    assert actual[0].disabled

# Try to get it with a user who is not a super user
with get_session_with_context() as session:
    not_super_user = session.exec(
        select(User).where(User.username == test_username)
    ).one()
    with pytest.raises(HTTPException) as e:
        get_all_users(offset=0, limit=1, user=not_super_user, session=session)
display(e)

[User(id=1, uuid=UUID('b660dce5-f500-4a5d-86ef-8f81372d368a'), username='johndoe', first_name='John', last_name='Doe', email='johndoe@airt.ai', subscription_type=<SubscriptionType.small: 'small'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 9, 8, 17, 31), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)]

[User(id=53, uuid=UUID('fc5cfad1-a2fb-4fbf-952b-97ae7884d11e'), username='vyfohgucmp', first_name='unittest', last_name='user', email='vyfohgucmp@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=True, created=datetime.datetime(2023, 3, 9, 11, 46, 14), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)]

<ExceptionInfo HTTPException(status_code=401, detail='You do not have sufficient permission to access this route. Please contact your administrator for help.') tblen=2>

In [None]:
# | export


@user_router.get("/details", response_model=UserRead)
def get_user_details(
    user_uuid_or_name: Optional[str] = None,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> User:
    """Get user details"""
    current_user = session.merge(user)

    # Without an explicit target the caller is asking for their own details
    if user_uuid_or_name is None:
        target_user = current_user
    else:
        target_user = get_valid_user(current_user, session, user_uuid_or_name)

    return User.get(target_user.uuid, session)  # type: ignore

In [None]:
# MFA enabled user trying to get their details
with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]

    actual = get_user_details(user_uuid_or_name=None, user=user, session=session)
    display(actual)
    assert actual.username == user.username

User(id=444, uuid=UUID('f4e217e1-48b3-4f5d-afe1-d2d3923ea849'), username='rmputetoaj', first_name='unittest', last_name='user', email='rmputetoaj@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 51), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

In [None]:
# Normal user getting their details. Passing user_uuid_or_name as None
with get_session_with_context() as session:
    not_super_user = session.exec(
        select(User).where(User.username == test_username)
    ).one()

    actual = get_user_details(
        user_uuid_or_name=None,
        user=not_super_user,
        session=session,
    )
    assert actual.id == not_super_user.id
    display(actual)

    # Normal user getting their details
    actual = get_user_details(
        user_uuid_or_name=str(not_super_user.uuid),
        user=not_super_user,
        session=session,
    )
    assert actual.id == not_super_user.id
    display(actual)

    # Normal user getting other user's details
    with pytest.raises(HTTPException) as e:
        get_user_details(
            user_uuid_or_name="random_user_name",
            user=not_super_user,
            session=session,
        )

    assert e.value.detail == "Insufficient permission to access other user's data"
    display(e.value.detail)

# Super user getting their details
with get_session_with_context() as session:
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()

    actual = get_user_details(
        user_uuid_or_name=str(user_kumaran.username),
        user=user_kumaran,
        session=session,
    )
    assert actual.id == user_kumaran.id

    actual = get_user_details(
        user_uuid_or_name=None,
        user=user_kumaran,
        session=session,
    )
    assert actual.id == user_kumaran.id

    # Super user getting other user's details
    actual = get_user_details(
        user_uuid_or_name=str(not_super_user.username),
        user=user_kumaran,
        session=session,
    )
    assert actual.id == not_super_user.id

    # Super user getting invalid user's details
    with pytest.raises(HTTPException) as e:
        get_user_details(
            user_uuid_or_name=INVALID_UUID_FOR_TESTING,
            user=user_kumaran,
            session=session,
        )

    assert "The user uuid is incorrect" in e.value.detail, e.value.detail
    display(e.value.detail)

User(id=425, uuid=UUID('9e7bfe37-bc47-483e-aae8-5d6b1a635cc1'), username='sqfarrgeoo', first_name='unittest', last_name='user', email='sqfarrgeoo@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 34), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

User(id=425, uuid=UUID('9e7bfe37-bc47-483e-aae8-5d6b1a635cc1'), username='sqfarrgeoo', first_name='unittest', last_name='user', email='sqfarrgeoo@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 34), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

"Insufficient permission to access other user's data"

'The user uuid is incorrect. Please try again.'

In [None]:
# | exporti


@patch  # type: ignore
def enable(self: SSO, session: Session, sso_email: EmailStr) -> SSO:
    """Mark this SSO record as enabled and store the given SSO email on it

    Args:
        session: Sqlmodel session
        sso_email: Email address to store on this record for the provider

    Returns:
        This SSO object, with disabled set to False and sso_email updated
    """
    with commit_or_rollback(session):
        # Both attribute writes and the add belong to the same unit of work.
        self.sso_email, self.disabled = sso_email, False
        session.add(self)

    return self

In [None]:
# | exporti


def check_valid_sso_provider(sso_provider: str) -> str:
    """Ensure the given name is one of the supported SSO providers

    Args:
        sso_provider: SSO provider name to check

    Returns:
        The given sso_provider, unchanged, when it is a supported value

    Raises:
        HTTPException: With status 400 if sso_provider is not one of the SSOProvider values
    """
    supported_providers = [provider.value for provider in SSOProvider]
    if sso_provider in supported_providers:
        return sso_provider

    # The supported values are appended to the message so the caller sees the valid choices.
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=f'{ERRORS["INVALID_SSO_PROVIDER"]}: {supported_providers}',
    )

In [None]:
# Every supported provider name must be returned unchanged
sso_providers = ["google", "github"]
for sp in sso_providers:
    actual = check_valid_sso_provider(sp)
    assert actual == sp

# An unsupported provider name must be rejected. Pass a plain string, matching the
# function's `sso_provider: str` parameter (a list only "worked" by accident because
# it is also not a member of the valid-provider list).
invalid_sso_provider = "invalid_sso_provider"
with pytest.raises(HTTPException) as e:
    check_valid_sso_provider(invalid_sso_provider)
assert "Invalid SSO provider" in e.value.detail
display(e.value.detail)

"Invalid SSO provider. Valid SSO providers are: ['google', 'github']"

In [None]:
# | export


class EnableSSORequest(SSOBase):
    """A base class for enabling sso for the account

    Args:
        otp: OTP passed by the user
    """

    otp: Optional[str] = None

    # pre=True: runs on the raw input before pydantic's own field validation, so an
    # unsupported provider surfaces as the HTTPException raised by
    # check_valid_sso_provider.
    @validator("sso_provider", pre=True)
    @classmethod
    def validate_sso_provider(cls, sso_provider: str) -> str:
        return check_valid_sso_provider(sso_provider)

    # sso_email is allowed to be None, so the annotations are Optional; a missing
    # email is passed through untouched and only a supplied value is checked.
    @validator("sso_email", pre=True)
    @classmethod
    def validate_email(cls, sso_email: Optional[str]) -> Optional[str]:
        if sso_email is None:
            return None
        return check_valid_email(sso_email)

In [None]:
# | export


@user_router.post("/sso/enable", response_model=SSORead)
@require_otp_if_mfa_enabled
def enable_sso(
    enable_sso_request: EnableSSORequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> SSO:
    """Enable SSO for the user"""
    # Attach the authenticated user to this session before it is used in queries.
    user = session.merge(user)

    sso_email = enable_sso_request.sso_email
    sso_provider = enable_sso_request.sso_provider

    # At most one SSO row is expected per (user, provider) pair.
    sso = session.exec(
        select(SSO).where(SSO.user == user).where(SSO.sso_provider == sso_provider)
    ).one_or_none()

    if sso_email is None:
        # No email in the request: never write None into the record. For an existing
        # record keep the SSO email already stored on it; otherwise (or if that is
        # empty) fall back to the user's account email.
        if sso is not None and sso.sso_email is not None:
            sso_email = sso.sso_email
        else:
            sso_email = user.email

    if sso is not None:
        # The provider was set up before: re-enable it and update its email.
        return sso.enable(session, sso_email)  # type: ignore

    # First time this provider is enabled for the user: create the record.
    with commit_or_rollback(session):
        _sso = SSO(sso_provider=sso_provider, sso_email=sso_email)
        _sso.user = user
        session.add(_sso)

    return _sso

In [None]:
# Smoke test: a freshly created test user enables Google SSO with an explicit SSO email
with get_session_with_context() as session:
    test_username = create_user_for_testing()
    user = session.exec(select(User).where(User.username == test_username)).one()
    enable_sso_request = EnableSSORequest(
        sso_provider="google",
        sso_email="random_email_id@mail.com",
    )
    actual = enable_sso(
        enable_sso_request=enable_sso_request, user=user, session=session
    )
    # The returned record must carry the email supplied in the request
    assert actual.sso_email == enable_sso_request.sso_email
display(actual)

SSO(id=80, sso_provider=<SSOProvider.google: 'google'>, username='slnynamced', sso_email='random_email_id@mail.com', disabled=False)

In [None]:
# Non-MFA enabled user trying to enable SSO

# Defined up front: the first two scenarios pass it to EnableSSORequest, and a
# NameError there would not be caught by pytest.raises(HTTPException). Previously it
# was only assigned in the third scenario, so the cell relied on a leaked global.
random_otp = 111111

with get_session_with_context() as session:
    test_username = create_user_for_testing()
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Negative scenario: passing invalid SSO provider in the request
    with pytest.raises(HTTPException) as e:
        enable_sso_request = EnableSSORequest(
            sso_provider="random_sso_provider",
            sso_email="random_email_id@mail.com",
            otp=random_otp,
        )
    assert "Invalid SSO provider" in e.value.detail
    display(e.value.detail)

    # Negative scenario: passing invalid email in the request
    with pytest.raises(HTTPException) as e:
        enable_sso_request = EnableSSORequest(
            sso_provider="google", sso_email="invalid_email.com", otp=random_otp
        )
    assert "Invalid email" in e.value.detail
    display(e.value.detail)

    # Negative scenario: passing OTP in the request
    # (the user has no MFA, so supplying any OTP must be rejected)
    with pytest.raises(HTTPException) as e:
        enable_sso_request = EnableSSORequest(
            sso_provider="google", sso_email="random_email_id@mail.com", otp=random_otp
        )
        enable_sso(enable_sso_request=enable_sso_request, user=user, session=session)
    assert "MFA is not activated for the account" in str(e.value.detail), str(
        e.value.detail
    )
    display(e.value.detail)

# Positive scenario: passing valid email in the request
with get_session_with_context() as session:
    test_username = create_user_for_testing()
    user = session.exec(select(User).where(User.username == test_username)).one()
    enable_sso_request = EnableSSORequest(
        sso_provider="google",
        sso_email="random_email_id@mail.com",
    )
    actual = enable_sso(
        enable_sso_request=enable_sso_request, user=user, session=session
    )
    assert actual.sso_email == enable_sso_request.sso_email
display(actual)

# Positive scenario: not passing email in the request, should take the one from our records
with get_session_with_context() as session:
    test_username = create_user_for_testing()
    user = session.exec(select(User).where(User.username == test_username)).one()
    enable_sso_request = EnableSSORequest(
        sso_provider="google",
        sso_email=None,
    )
    actual = enable_sso(
        enable_sso_request=enable_sso_request, user=user, session=session
    )
    assert actual.sso_email == user.email
display(actual)

"Invalid SSO provider. Valid SSO providers are: ['google', 'github']"

'Invalid email format. Please try again.'

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

SSO(id=81, sso_provider=<SSOProvider.google: 'google'>, username='ddlnsigwcr', sso_email='random_email_id@mail.com', disabled=False)

SSO(id=82, sso_provider=<SSOProvider.google: 'google'>, username='kzssaplnha', sso_email='kzssaplnha@email.com', disabled=False)

In [None]:
# MFA enabled user trying to enable SSO
with create_mfa_enabled_user() as user_and_session:
    # The context manager yields a (user, session) pair
    user = user_and_session[0]
    session = user_and_session[1]
    # Positive scenario: passing valid OTP in the request
    enable_sso_request = EnableSSORequest(
        sso_provider="google",
        sso_email="random_email_id@mail.com",
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
    )
    actual = enable_sso(
        enable_sso_request=enable_sso_request, user=user, session=session
    )
    display(actual)

    # Positive scenario: enabling an already enabled sso_provider again with a
    # different email updates the stored sso_email
    new_sso_email = "new_random_email_id@mail.com"
    enable_sso_request = EnableSSORequest(
        sso_provider="google",
        sso_email=new_sso_email,
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
    )
    actual = enable_sso(
        enable_sso_request=enable_sso_request, user=user, session=session
    )
    assert actual.sso_email == new_sso_email
    display(actual)

SSO()

SSO(id=83, sso_provider=<SSOProvider.google: 'google'>, username='igxqvcckai', sso_email='new_random_email_id@mail.com', disabled=False)

In [None]:
# MFA enabled user trying to enable SSO
with create_mfa_enabled_user() as user_and_session:
    # The context manager yields a (user, session) pair
    user = user_and_session[0]
    session = user_and_session[1]
    # Negative scenario: Not passing the OTP in the request
    with pytest.raises(HTTPException) as e:
        enable_sso_request = EnableSSORequest(
            sso_provider="google", sso_email="random_email_id@mail.com"
        )
        enable_sso(enable_sso_request=enable_sso_request, user=user, session=session)
    assert "OTP is required" in str(e.value.detail)
    display(e.value.detail)

    # Negative scenario: passing invalid OTP in the request
    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        enable_sso_request = EnableSSORequest(
            sso_provider="google", sso_email="random_email_id@mail.com", otp=random_otp
        )
        enable_sso(enable_sso_request=enable_sso_request, user=user, session=session)
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Positive scenario: passing valid OTP in the request
    # (the current TOTP code is derived from the user's own MFA secret)
    enable_sso_request = EnableSSORequest(
        sso_provider="google",
        sso_email="random_email_id@mail.com",
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
    )
    actual = enable_sso(
        enable_sso_request=enable_sso_request, user=user, session=session
    )

    assert actual.sso_email == enable_sso_request.sso_email
display(actual)

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

SSO()

In [None]:
# | exporti


@patch  # type: ignore
def disable(self: SSO, session: Session) -> SSO:
    """Disable SSO for a particular service

    Args:
        session: Sqlmodel session

    Returns:
        The disabled SSO object

    Raises:
        HTTPException: if SSO is already disabled
    """
    # Disabling twice is reported to the caller as a 400 rather than silently ignored.
    if self.disabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["SSO_ALREADY_DISABLED"],
        )

    with commit_or_rollback(session):
        self.disabled = True
        session.add(self)

    return self

In [None]:
# | export


@user_router.delete(
    "/sso/{user_uuid_or_name}/disable/{sso_provider}",
    response_model=SSORead,
    responses={
        400: {
            "model": HTTPError,
            "description": ERRORS["INCORRECT_USER_UUID"],
        },
        401: {
            "model": HTTPError,
            "description": ERRORS["NOT_ENOUGH_PERMISSION_TO_ACCESS_OTHERS_DATA"],
        },
    },
)
@require_otp_if_mfa_enabled
def disable_sso(
    user_uuid_or_name: str,
    sso_provider: str,
    otp: Optional[str] = None,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> SSO:
    """Disable SSO"""

    user = session.merge(user)

    # The provider name is checked before the target user is resolved, so an invalid
    # provider is reported first.
    sso_provider = check_valid_sso_provider(sso_provider)
    target_user = get_valid_user(user, session, user_uuid_or_name)

    sso_query = (
        select(SSO)
        .where(SSO.user == target_user)
        .where(SSO.sso_provider == sso_provider)
    )
    sso_record = session.exec(sso_query).one_or_none()

    # No row for this (user, provider) pair means SSO was never enabled for it.
    if sso_record is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["SSO_NOT_ENABLED_FOR_SERVICE"],
        )

    return sso_record.disable(session)  # type: ignore

In [None]:
# MFA enabled user trying to disable SSO
with create_mfa_enabled_user() as user_and_session:
    # The context manager yields a (user, session) pair
    user = user_and_session[0]
    session = user_and_session[1]
    sso_provider = "google"

    # Negative scenario: normal user trying to disable sso for others
    with pytest.raises(HTTPException) as e:
        otp = pyotp.TOTP(user_and_session[0].mfa_secret).now()
        disable_sso(
            user_uuid_or_name=INVALID_UUID_FOR_TESTING,
            sso_provider=sso_provider,
            otp=otp,
            user=user,
            session=session,
        )
    assert "Insufficient permission" in e.value.detail, e.value.detail
    display(e.value.detail)

    # Negative scenario: Not passing the OTP in the request
    with pytest.raises(HTTPException) as e:
        disable_sso(
            user_uuid_or_name=str(user.uuid),
            sso_provider=sso_provider,
            otp=None,
            user=user,
            session=session,
        )
    assert "OTP is required" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Negative scenario: Passing invalid OTP in the request
    # Use the real parameter name (user_uuid_or_name); the former `user_uuid=` keyword
    # does not exist on disable_sso and went unnoticed only because the call was
    # rejected for the invalid OTP first.
    with pytest.raises(HTTPException) as e:
        invalid_otp = 111111
        disable_sso(
            user_uuid_or_name=str(user.uuid),
            sso_provider=sso_provider,
            otp=invalid_otp,
            user=user,
            session=session,
        )
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    sso_provider = "google"
    # Setup: enable SSO for the provider so that it can be disabled below
    enable_sso_request = EnableSSORequest(
        sso_provider=sso_provider,
        sso_email="random_email_id@mail.com",
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
    )
    enabled_sso_provider = enable_sso(
        enable_sso_request=enable_sso_request, user=user, session=session
    )
    display(enabled_sso_provider)

    # Negative scenario: passing invalid sso_provider
    with pytest.raises(HTTPException) as e:
        disable_sso(
            user_uuid_or_name=str(user.uuid),
            sso_provider="invalid_sso_provider",
            otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
            user=user,
            session=session,
        )
    assert "Invalid SSO provider" in e.value.detail
    display(e.value.detail)

    # Positive scenario: disabling SSO for particular service
    actual = disable_sso(
        user_uuid_or_name=str(user.uuid),
        sso_provider=sso_provider,
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
        user=user,
        session=session,
    )
    display(actual)
    assert actual.sso_provider == sso_provider
    assert actual.user_id == user.id

    # Negative scenario: disabling already disabled SSO provider
    with pytest.raises(HTTPException) as e:
        disable_sso(
            user_uuid_or_name=str(user.uuid),
            sso_provider=sso_provider,
            otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
            user=user,
            session=session,
        )
    assert "SSO is already disabled" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

with create_mfa_enabled_user() as user_and_session:
    user = user_and_session[0]
    session = user_and_session[1]
    sso_provider = "google"
    # Positive scenario: enabling a disabled SSO provider
    # NOTE(review): this block starts from a new MFA user, so the provider was never
    # disabled for them here - confirm whether re-enabling after disable_sso was the
    # intended coverage.
    new_sso_email = "new_email_address@mail.com"
    enable_sso_request = EnableSSORequest(
        sso_provider=sso_provider,
        sso_email=new_sso_email,
        otp=pyotp.TOTP(user_and_session[0].mfa_secret).now(),
    )
    enabled_sso_provider = enable_sso(
        enable_sso_request=enable_sso_request, user=user, session=session
    )
    display(enabled_sso_provider)
    assert (
        enabled_sso_provider.sso_email == new_sso_email
    ), enabled_sso_provider.sso_email

"Insufficient permission to access other user's data"

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

SSO()

"Invalid SSO provider. Valid SSO providers are: ['google', 'github']"

SSO()

'SSO is already disabled for the provider.'

SSO()

In [None]:
# MFA enabled super user trying to disable SSO
# Setup: a regular test user with Google SSO enabled, to be disabled by the super user below
with get_session_with_context() as session:
    normal_user = session.exec(
        select(User).where(User.username == create_user_for_testing())
    ).one()
    enable_sso_request = EnableSSORequest(
        sso_provider="google", sso_email="random_email_id@mail.com"
    )
    actual = enable_sso(
        enable_sso_request=enable_sso_request, user=normal_user, session=session
    )
    display(actual)

    assert actual.sso_provider == enable_sso_request.sso_provider
    assert actual.sso_email == enable_sso_request.sso_email
    assert actual.user_id == normal_user.id


with create_mfa_enabled_super_user() as user_and_session:
    # The context manager yields a (user, session) pair
    user = user_and_session[0]
    session = user_and_session[1]
    display(user)
    sso_provider = "google"
    # Positive Scenario: disabling SSO for others
    otp = pyotp.TOTP(user.mfa_secret).now()
    actual = disable_sso(
        user_uuid_or_name=str(normal_user.uuid),
        sso_provider=sso_provider,
        otp=otp,
        user=user,
        session=session,
    )
    display(actual)

    assert actual.sso_provider == sso_provider
    assert actual.disabled

    # Positive Scenario: disabling SSO for self
    # (SSO is first enabled for the super user so there is something to disable)
    enable_sso_request = EnableSSORequest(
        sso_provider="google",
        sso_email="random_email_id@mail.com",
        otp=pyotp.TOTP(user.mfa_secret).now(),
    )
    actual = enable_sso(
        enable_sso_request=enable_sso_request, user=user, session=session
    )
    display(actual)
    assert actual.sso_email == enable_sso_request.sso_email

    otp = pyotp.TOTP(user.mfa_secret).now()
    actual = disable_sso(
        user_uuid_or_name=str(user.uuid),
        sso_provider=sso_provider,
        otp=otp,
        user=user,
        session=session,
    )
    display(actual)
    assert actual.sso_provider == sso_provider
    assert actual.disabled

SSO()

User(id=455, uuid=UUID('86513363-31b6-4107-a8b1-9a2e31fe1dde'), username='imyjigngrw', first_name='first_name_imyjigngrw', last_name='last_name_imyjigngrw', email='imyjigngrw@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=True, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 54), phone_number=None, is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

SSO()

SSO()

SSO()

In [None]:
# Non-MFA user trying to disable SSO
with get_session_with_context() as session:
    username_for_testing = create_user_for_testing()
    normal_user = session.exec(
        select(User).where(User.username == username_for_testing)
    ).one()
    sso_provider = "google"
    # Setup: enable SSO first so the disable scenarios below have a record to act on
    enable_sso_request = EnableSSORequest(
        sso_provider=sso_provider, sso_email="random_email_id@mail.com"
    )
    actual = enable_sso(
        enable_sso_request=enable_sso_request, user=normal_user, session=session
    )
    display(actual)
    assert actual.sso_provider == sso_provider
    assert actual.sso_email == enable_sso_request.sso_email
    assert actual.user_id == normal_user.id
    assert not actual.disabled

    # Positive Scenario: disabling SSO for self
    actual = disable_sso(
        user_uuid_or_name=str(normal_user.uuid),
        sso_provider=sso_provider,
        otp=None,
        user=normal_user,
        session=session,
    )
    display(actual)
    assert actual.sso_provider == sso_provider
    assert actual.disabled

    # Negative Scenario: disabling SSO for already SSO disabled user
    with pytest.raises(HTTPException) as e:
        disable_sso(
            user_uuid_or_name=str(normal_user.uuid),
            sso_provider=sso_provider,
            otp=None,
            user=normal_user,
            session=session,
        )
    display(e.value.detail)
    assert "SSO is already disabled" in str(e.value.detail)

    # Negative Scenario: passing OTP
    # (the user has no MFA, so supplying any OTP must be rejected)
    with pytest.raises(HTTPException) as e:
        random_otp = 123456
        disable_sso(
            user_uuid_or_name=str(normal_user.uuid),
            sso_provider=sso_provider,
            otp=random_otp,
            user=normal_user,
            session=session,
        )
    display(e.value.detail)
    assert "MFA is not activated for the account" in str(e.value.detail), str(
        e.value.detail
    )

    # Negative Scenario: disabling SSO for others
    with pytest.raises(HTTPException) as e:
        disable_sso(
            user_uuid_or_name=INVALID_UUID_FOR_TESTING,
            sso_provider=sso_provider,
            otp=None,
            user=normal_user,
            session=session,
        )
    display(e.value.detail)
    assert "Insufficient permission" in str(e.value.detail), str(e.value.detail)

SSO()

SSO()

'SSO is already disabled for the provider.'

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

"Insufficient permission to access other user's data"

In [None]:
# | export


def create_trial_user(subscription_type: str, session: Session) -> User:
    """Create a trial user for the given subscription_type

    The username, first/last name and email are derived from subscription_type and a
    random ten-letter lowercase suffix. The password is an independent random token.

    Args:
        subscription_type: Subscription type for the new user; also used as the prefix
            of the generated username, names and email
        session: Sqlmodel session

    Returns:
        The newly created user object
    """

    # Only needs to be unique-ish, not unpredictable, hence random rather than secrets.
    username = "".join(
        random.choice(string.ascii_lowercase) for _ in range(10)  # nosec B311
    )

    user_to_create = UserCreate(
        username=f"{subscription_type}_{username}",
        first_name=f"{subscription_type}_first_name",
        last_name=f"{subscription_type}_last_name",
        email=f"{subscription_type}_{username}@email.com",
        # Never derive the password from the username: the username is not a secret
        # (sso_signup hands it to the SSO flow), so a matching password would let
        # anyone who learns the username log in.
        # NOTE(review): assumes no caller authenticates a trial user with the old
        # "<subscription_type>_<suffix>" password - confirm.
        password=secrets.token_urlsafe(32),
        subscription_type=subscription_type,
    )
    return User._create(user_to_create, session)  # type: ignore

In [None]:
# A trial user's username must be prefixed with the subscription type
with get_session_with_context() as session:
    new_user = create_trial_user(subscription_type="captn_trial", session=session)
display(new_user)
assert "captn_trial_" in new_user.username, new_user.username

User(id=457, uuid=UUID('886946dc-0fe1-43a1-bb24-281485f161b2'), username='captn_trial_wzcwqeguyk', first_name='captn_trial_first_name', last_name='captn_trial_last_name', email='captn_trial_wzcwqeguyk@email.com', subscription_type=<SubscriptionType.captn_trial: 'captn_trial'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 55), phone_number=None, is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

In [None]:
# | export


@user_router.get("/sso_signup")
def sso_signup(subscription_type: str, sso_provider: str) -> str:
    """Method to create new user with SSO"""
    # Reject an unsupported provider before anything is created. Otherwise the trial
    # user below would already exist by the time EnableSSORequest raises for the
    # invalid provider, which may leave an orphan account behind.
    sso_provider = check_valid_sso_provider(sso_provider)

    with get_session_with_context() as session:
        # 1. Create Trial user
        trial_user = create_trial_user(
            subscription_type=subscription_type, session=session
        )
        # 2. Enable SSO for the Trial user
        enable_sso_request = EnableSSORequest(
            sso_provider=sso_provider, sso_email=trial_user.email
        )
        # __wrapped__ calls the undecorated enable_sso, i.e. without the
        # require_otp_if_mfa_enabled wrapper.
        sso = enable_sso.__wrapped__(  # type: ignore
            enable_sso_request=enable_sso_request, user=trial_user, session=session
        )
        # 3. get authorization URL
        return initiate_sso_flow(
            username=trial_user.username,
            sso_provider=sso_provider,
            nonce=secrets.token_hex(),
            sso=sso,
        ).authorization_url

In [None]:
# sso_signup opens its own session internally, so no outer session is needed here.
authorization_url = sso_signup(
    subscription_type="captn_trial", sso_provider="google"
)
# The URL must reference both the generated trial username and the provider
assert "captn_trial" in authorization_url
assert "google" in authorization_url
print(authorization_url)

https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=222935564406-p6orfelpk34tsm03v5dm692v1gsb0qfo.apps.googleusercontent.com&redirect_uri=http%3A%2F%2F127.0.0.1%3A6006%2Fsso%2Fcallback&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+openid&state=eb154a64a5a294329b46e610b11919ed086507ff995f32fa1c3f4a94d77b9479_captn_trial_ytcqpblxdg&prompt=select_account


In [None]:
# | export


class RegisterPhoneNumberRequest(BaseModel):
    """A base class for registering a new phone number

    Args:
        phone_number: User's new phone number to add in the db
        otp: Dynamically generated six-digit verification code from the authenticator app
    """

    # Optional: when omitted, register_phone_number falls back to the phone number
    # already stored on the user (and fails if there is none).
    phone_number: Optional[str] = None
    # Optional: only expected for users with MFA enabled (see require_otp_if_mfa_enabled).
    otp: Optional[str] = None

In [None]:
# | export


@user_router.post(
    "/register_phone_number",
    response_model=UserRead,
    responses={
        400: {"model": HTTPError, "description": ERRORS["NO_PHONE_NUMBER_TO_REGISTER"]},
    },
)
@require_otp_if_mfa_enabled
def register_phone_number(
    register_phone_number_request: RegisterPhoneNumberRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> User:
    """Register a new phone number for the user"""
    user = session.merge(user)

    requested_number = register_phone_number_request.phone_number

    if requested_number is not None:
        # A number was supplied: store it and reset the verified flag, since the
        # number on record has not been confirmed via OTP yet.
        with commit_or_rollback(session):
            user.is_phone_number_verified = False
            user.phone_number = requested_number
            session.add(user)
        number_to_verify = requested_number
    elif user.phone_number is not None:
        # Nothing supplied: reuse the number already stored on the user.
        number_to_verify = user.phone_number
    else:
        # Neither the request nor the user record has a number to send the OTP to.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["NO_PHONE_NUMBER_TO_REGISTER"],
        )

    return _send_sms_otp_to_user(
        user=user,
        message_template_name="register_phone_number",
        session=session,
        phone_number=number_to_verify,
    )

In [None]:
# Negative Scenario for MFA user: phone number not passed in arguments nor set in the db
with create_mfa_enabled_user() as user_and_session:
    # The context manager yields a (user, session) pair
    user = user_and_session[0]
    session = user_and_session[1]

    # Empty request: neither phone_number nor otp is set
    register_phone_number_request = RegisterPhoneNumberRequest()

    # Negative scenario: Not passing the OTP in the request
    with pytest.raises(HTTPException) as e:
        register_phone_number(
            register_phone_number_request=register_phone_number_request,
            user=user,
            session=session,
        )
    assert "OTP is required" in str(e.value.detail)
    display(e.value.detail)

    # Negative scenario: passing invalid OTP in the request
    with pytest.raises(HTTPException) as e:
        random_otp = 111111
        register_phone_number_request = RegisterPhoneNumberRequest(otp=random_otp)
        register_phone_number(
            register_phone_number_request=register_phone_number_request,
            user=user,
            session=session,
        )
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Negative scenario: valid OTP, but no phone number in the request or on the user
    with pytest.raises(HTTPException) as e:
        otp = pyotp.TOTP(user.mfa_secret).now()
        register_phone_number_request = RegisterPhoneNumberRequest(otp=otp)
        register_phone_number(
            register_phone_number_request=register_phone_number_request,
            user=user,
            session=session,
        )
    display(e.value.detail)
    assert "Please pass a phone number to register" in str(e.value.detail), str(
        e.value.detail
    )

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

'Please pass a phone number to register.'

In [None]:
# Mocking Positive scenario for MFA: smsStatus is set to MESSAGE_SENT

with MonkeyPatch.context() as monkeypatch:
    # Canned response returned in place of a real send_sms call
    sample_response = {
        "pinId": "my_random_pin_id",
        "to": "910000000000",
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    # Replace send_sms for the duration of this context so no real SMS is sent
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: sample_response
    )

    with create_mfa_enabled_user() as user_and_session:
        # The context manager yields a (user, session) pair
        user = user_and_session[0]
        session = user_and_session[1]

        random_phone_number = "910000000000"
        register_phone_number_request = RegisterPhoneNumberRequest(
            phone_number=random_phone_number, otp=pyotp.TOTP(user.mfa_secret).now()
        )
        actual = register_phone_number(
            register_phone_number_request=register_phone_number_request,
            user=user,
            session=session,
        )
        display(actual)
        # The new number is stored but stays unverified until validated
        assert user.phone_number == random_phone_number
        assert user.is_phone_number_verified == False

        # Exactly one SMS row is expected for the user (.one() fails otherwise)
        sms = session.exec(select(SMS).where(SMS.user == user)).one()

        display(sms)
        assert sms.user_id == user.id

        # ... and exactly one SMSProtocol row linked to that SMS row
        sms_protocol = session.exec(
            select(SMSProtocol).where(SMSProtocol.sms_id == sms.id)
        ).one()

        display(sms_protocol)
        assert sms_protocol.sms_id == sms.id

        # Triggering Multiple phone number register request and making sure the previous records in the sms_protocol are cleared
        register_phone_number_request = RegisterPhoneNumberRequest(
            phone_number=random_phone_number, otp=pyotp.TOTP(user.mfa_secret).now()
        )
        actual = register_phone_number(
            register_phone_number_request=register_phone_number_request,
            user=user,
            session=session,
        )
        display(actual)

        sms = session.exec(select(SMS).where(SMS.user == user)).one()

        display(sms)
        assert sms.user_id == user.id

        # .one() succeeding here shows there is still only a single SMSProtocol row
        sms_protocol_1 = session.exec(
            select(SMSProtocol).where(SMSProtocol.sms_id == sms.id)
        ).one()

        display(sms_protocol_1)
        assert sms_protocol_1.sms_id == sms.id

        # A different id shows the earlier protocol row was replaced by a new one
        display(f"{sms_protocol.id=} != {sms_protocol_1.id=}")
        assert sms_protocol.id != sms_protocol_1.id

User(id=460, uuid=UUID('ac2f6892-5987-449d-826f-8114acf8f8cf'), username='jxrobsaduj', first_name='unittest', last_name='user', email='jxrobsaduj@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 56), phone_number='910000000000', is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

SMS(message_id='A0AB8E85E0FBD53F5A7F14121E484A4D', application_id='0636B234BE48C57A95BB07AA93A5483E', id=34, user_id=460)

SMSProtocol(id=34, number_lookup_status='NC_NOT_CONFIGURED', phone_number='910000000000', pin_attempts_remaining=None, pin_id='my_random_pin_id', sent_sms_status='MESSAGE_SENT', pin_verified=False, sms_id=34)

User(id=460, uuid=UUID('ac2f6892-5987-449d-826f-8114acf8f8cf'), username='jxrobsaduj', first_name='unittest', last_name='user', email='jxrobsaduj@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 53, 56), phone_number='910000000000', is_phone_number_verified=False, mfa_secret=**********************************, is_mfa_active=True)

SMS(message_id='A0AB8E85E0FBD53F5A7F14121E484A4D', application_id='0636B234BE48C57A95BB07AA93A5483E', id=34, user_id=460)

SMSProtocol(id=35, number_lookup_status='NC_NOT_CONFIGURED', phone_number='910000000000', pin_attempts_remaining=None, pin_id='my_random_pin_id', sent_sms_status='MESSAGE_SENT', pin_verified=False, sms_id=34)

'sms_protocol.id=34 != sms_protocol_1.id=35'

In [None]:
# Mocking Negative scenario for Non-MFA: smsStatus is set to MESSAGE_NOT_SENT

# Use a dedicated user created in this cell instead of relying on a `test_username`
# left behind by an unrelated earlier cell.
test_username = create_user_for_testing()

with MonkeyPatch.context() as monkeypatch:
    # Canned response returned in place of a real send_sms call
    sample_response = {
        "pinId": "my_random_pin_id",
        "to": "910000000000",
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_NOT_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: sample_response
    )

    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()

        random_phone_number = "910000000000"
        register_phone_number_request = RegisterPhoneNumberRequest(
            phone_number=random_phone_number
        )
        # A MESSAGE_NOT_SENT status must surface to the caller as an HTTPException
        with pytest.raises(HTTPException) as e:
            register_phone_number(
                register_phone_number_request=register_phone_number_request,
                user=user,
                session=session,
            )
        display(e.value.detail)
        assert "Failed to send OTP via SMS" in str(e.value.detail), str(e.value.detail)

'Failed to send OTP via SMS. Please check the phone number you have registered is valid and can receive SMS. Also, make sure the format of the phone number you have entered follows the pattern of country code followed by your phone number (without spaces). For example, 440123456789, +440123456789, and 00440123456789 are all valid formats for registering a UK phone number.'

In [None]:
# Mocking Negative scenario for Non-MFA: SMS request limit expired for the phone number

# Use a dedicated user created in this cell instead of relying on a `test_username`
# left behind by an unrelated earlier cell.
test_username = create_user_for_testing()

with MonkeyPatch.context() as monkeypatch:
    # Canned error response returned in place of a real send_sms call
    sample_response = {
        "requestError": {
            "serviceException": {
                "messageId": "THROTTLE_EXCEPTION",
                "text": "Too many requests. Try again later.",
            }
        }
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: sample_response
    )

    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()

        random_phone_number = "910000000000"
        register_phone_number_request = RegisterPhoneNumberRequest(
            phone_number=random_phone_number
        )
        # The provider's error text must be passed on in the HTTPException detail
        with pytest.raises(HTTPException) as e:
            register_phone_number(
                register_phone_number_request=register_phone_number_request,
                user=user,
                session=session,
            )
        display(e.value.detail)
        assert "Too many requests. Try again later" in str(e.value.detail), str(
            e.value.detail
        )

'Too many requests. Try again later.'

In [None]:
# | export


@user_router.get(
    "/validate_phone_number",
    response_model=UserRead,
    responses={
        400: {"model": HTTPError, "description": ERRORS["PHONE_NUMBER_NOT_REGISTERED"]},
    },
)
def validate_phone_number(
    otp: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> User:
    """Validate user's phone number

    Args:
        otp: The OTP to validate against the "register_phone_number" message template
        user: The currently active user, injected by the dependency
        session: Database session, injected by the dependency

    Returns:
        The user, with the phone number marked as verified

    Raises:
        HTTPException: With status code 400 if the user has no phone number registered.
            Any HTTPException raised by validate_otp while checking the OTP is propagated.
    """
    # Attach the incoming user object to this session before reading or updating it
    merged_user = session.merge(user)

    phone_number_missing = merged_user.phone_number is None
    if phone_number_missing:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["PHONE_NUMBER_NOT_REGISTERED"],
        )

    # Raises on failure, so reaching the next statement means the OTP was accepted
    validate_otp(
        user=merged_user,
        otp=otp,
        message_template_name="register_phone_number",
        session=session,
    )

    with commit_or_rollback(session):
        merged_user.is_phone_number_verified = True
        session.add(merged_user)

    return merged_user

In [None]:
# Negative scenario: validating an OTP before any phone number has been registered
test_username = create_user_for_testing()

with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    with pytest.raises(HTTPException) as e:
        # validate_phone_number declares `otp: str`, so pass the OTP as a string
        invalid_otp = "123456"
        validate_phone_number(otp=invalid_otp, user=user, session=session)
    display(e.value.detail)
    assert "The phone number is not yet registered" in str(e.value.detail), str(
        e.value.detail
    )

'The phone number is not yet registered. Please register your phone number before calling this method.'

In [None]:
# Mocking Positive scenario: register a phone number, then validate it with an
# OTP that the mocked verify_pin reports as verified
test_username = create_user_for_testing()
with MonkeyPatch.context() as monkeypatch:
    send_sms_sample_response = {
        "pinId": "my_random_pin_id",
        "to": "910000000000",
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    verify_pin_sample_response = {
        "pinId": "my_random_pin_id",
        "msisdn": "910000000000",
        "verified": True,
        "attemptsRemaining": 0,
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        # register phone number
        random_phone_number = "910000000000"
        register_phone_number_request = RegisterPhoneNumberRequest(
            phone_number=random_phone_number
        )
        actual = register_phone_number(
            register_phone_number_request=register_phone_number_request,
            user=user,
            session=session,
        )
        display(actual)
        assert user.phone_number == random_phone_number
        assert user.is_phone_number_verified is False

        # validate phone number
        # validate_phone_number declares `otp: str`, so pass the OTP as a string
        random_otp = "111111"
        actual = validate_phone_number(otp=random_otp, user=user, session=session)

        display(actual)
        assert user.phone_number == random_phone_number
        assert user.is_phone_number_verified is True

        sms = session.exec(select(SMS).where(SMS.user == user)).one()

        display(sms)
        assert sms.user_id == user.id

        # After a successful validation no SMSProtocol row is expected for this SMS
        with pytest.raises(NoResultFound) as e:
            sms_protocol = session.exec(
                select(SMSProtocol).where(SMSProtocol.sms_id == sms.id)
            ).one()
        display(e.value)

User(id=462, uuid=UUID('0e38ef46-ff5a-4d33-9d7e-ed5c24702739'), username='iapfvuxcul', first_name='unittest', last_name='user', email='iapfvuxcul@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 54, 10), phone_number='910000000000', is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

User(id=462, uuid=UUID('0e38ef46-ff5a-4d33-9d7e-ed5c24702739'), username='iapfvuxcul', first_name='unittest', last_name='user', email='iapfvuxcul@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 54, 10), phone_number='910000000000', is_phone_number_verified=True, mfa_secret=****, is_mfa_active=False)

SMS(message_id='A0AB8E85E0FBD53F5A7F14121E484A4D', application_id='0636B234BE48C57A95BB07AA93A5483E', id=36, user_id=462)

sqlalchemy.exc.NoResultFound('No row was found when one was required')

In [None]:
# Mocking Negative scenario: wrong PIN with attemptsRemaining = 0
test_username = create_user_for_testing()
with MonkeyPatch.context() as monkeypatch:
    send_sms_sample_response = {
        "pinId": "my_random_pin_id",
        "to": "910000000000",
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    verify_pin_sample_response = {
        "pinId": "my_random_pin_id",
        "msisdn": "910000000000",
        "verified": False,
        "attemptsRemaining": 0,
        "pinError": "WRONG_PIN",
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()

        random_phone_number = "910000000000"
        register_phone_number_request = RegisterPhoneNumberRequest(
            phone_number=random_phone_number
        )
        actual = register_phone_number(
            register_phone_number_request=register_phone_number_request,
            user=user,
            session=session,
        )
        display(actual)
        assert user.phone_number == random_phone_number
        assert user.is_phone_number_verified is False

        # validate_phone_number declares `otp: str`, so pass the OTP as a string
        random_otp = "111111"

        # First attempt: the mocked provider reports a wrong PIN
        with pytest.raises(HTTPException) as e:
            validate_phone_number(otp=random_otp, user=user, session=session)

        display(e.value.detail)
        assert "Incorrect OTP" in str(e.value.detail), str(e.value.detail)

        # Second attempt: no attempts remain, so a different error is expected
        with pytest.raises(HTTPException) as e:
            validate_phone_number(otp=random_otp, user=user, session=session)

        display(e.value.detail)
        assert "Too many failed attempts" in str(e.value.detail), str(e.value.detail)

        sms = session.exec(select(SMS).where(SMS.user == user)).one()

        display(sms)
        assert sms.user_id == user.id

        # Unlike the positive scenario, the SMSProtocol row is still present
        sms_protocol = session.exec(
            select(SMSProtocol).where(SMSProtocol.sms_id == sms.id)
        ).one()

        display(sms_protocol)
        assert sms_protocol.sms_id == sms.id

User(id=463, uuid=UUID('d16cc2a2-1c61-423a-88ed-904ad14ae4c4'), username='jqpxxbisey', first_name='unittest', last_name='user', email='jqpxxbisey@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 54, 18), phone_number='910000000000', is_phone_number_verified=False, mfa_secret=****, is_mfa_active=False)

'Incorrect OTP. Please enter the OTP you have received on your registered phone number and try again.'

'Too many failed attempts. Please initiate the phone registration process again.'

SMS(message_id='A0AB8E85E0FBD53F5A7F14121E484A4D', application_id='0636B234BE48C57A95BB07AA93A5483E', id=37, user_id=463)

SMSProtocol(id=37, number_lookup_status='NC_NOT_CONFIGURED', phone_number='910000000000', pin_attempts_remaining=0, pin_id='my_random_pin_id', sent_sms_status='MESSAGE_SENT', pin_verified=False, sms_id=37)

In [None]:
# | export


class ResetPasswordRequest(BaseModel):
    """Request object to reset user's password

    Args:
        username: Username of the account whose password is to be reset
        new_password: New password to set for the user's account
        otp: Dynamically generated six-digit verification code from the authenticator app or the OTP received via SMS
    """

    # Account identifier; read by name from this request by require_otp_or_totp
    username: str
    # Plain-text replacement password; reset_password hashes it before storing
    new_password: str
    # Kept as a string; checked as a TOTP and/or SMS OTP by require_otp_or_totp
    otp: str

In [None]:
# | export


def require_otp_or_totp(message_template_name: str) -> Callable[..., Any]:
    """A decorator function to validate the totp/otp

    If the totp/otp validation fails, the user will not be granted access to the decorated route.

    The decorated function must be called with keyword arguments: the username and
    otp are looked up by name in the keyword arguments (via get_attr_by_name), and
    the database session is read from the "session" keyword argument.

    For users with MFA active, the code is first checked as a TOTP. If that check
    fails (or MFA is not active), the code is checked as an SMS OTP.

    Args:
        message_template_name: Name of the message template that was used to send the SMS

    Returns:
        A decorator that wraps a function with the totp/otp validation

    Raises:
        HTTPException: With status code 401 (raised by the wrapped function at call
            time) if the username is unknown or neither validation succeeds
    """

    def outer_wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
            username = get_attr_by_name(kwargs, "username")
            otp_or_totp = get_attr_by_name(kwargs, "otp")
            session = kwargs["session"]

            user = get_user(username)  # type: ignore
            if user is None:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail=ERRORS["INCORRECT_USERNAME_OR_OTP"],
                )

            if user.is_mfa_active:
                try:
                    validate_totp(user.mfa_secret, otp_or_totp)  # type: ignore
                except HTTPException:
                    # Not a valid TOTP: deliberately fall through and try the code as an SMS OTP
                    pass
                else:
                    # Call the decorated function outside the try block so that an
                    # HTTPException it raises is propagated rather than mistaken
                    # for a failed TOTP check (which could also run it twice).
                    return func(*args, **kwargs)

            try:
                validate_otp(
                    user=user,
                    otp=otp_or_totp,  # type: ignore
                    message_template_name=message_template_name,
                    session=session,
                )
            except HTTPException as e:
                # Replace the specific OTP failure with the generic 401 error
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail=ERRORS["INCORRECT_USERNAME_OR_OTP"],
                ) from e

            return func(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper

In [None]:
# Negative scenario: the decorator must reject an unknown username
reset_password_request = ResetPasswordRequest(
    username="this_is_an_invalid_username", new_password="random_password", otp="000000"
)


@require_otp_or_totp(message_template_name="reset_password")
def test_require_otp_or_totp(
    reset_password_request,
    session,
):
    """Minimal decorated function; returns "Ok" only if the decorator lets the call through"""
    return "Ok"


with get_session_with_context() as session:
    with pytest.raises(HTTPException) as e:
        test_require_otp_or_totp(
            reset_password_request=reset_password_request, session=session
        )
    display(e.value.detail)
    # Pin the failure mode instead of only displaying it
    assert e.value.status_code == status.HTTP_401_UNAUTHORIZED, e.value.status_code
    assert "Something went wrong" in str(e.value.detail), str(e.value.detail)

'Something went wrong. The username or OTP you entered is incorrect. Please try again or contact your administrator.'

In [None]:
# Positive scenario: MFA-enabled user with a verified phone number submits a code
# that is accepted as an SMS OTP (verify_pin is mocked to report success)
test_username = create_user_for_testing()
random_phone_number = "910000000000"
random_sms_pin_id = "my_random_pin_id"
message_template_name = "reset_password"
random_sms_otp = "111111"
with MonkeyPatch.context() as monkeypatch:
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    verify_pin_sample_response = {
        "pinId": random_sms_pin_id,
        "msisdn": random_phone_number,
        "verified": True,
        "attemptsRemaining": 0,
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    # Setup: give the user a verified phone number and activate MFA
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        user.phone_number = random_phone_number
        user.is_phone_number_verified = True
        session.add(user)
        session.commit()
        session.refresh(user)

        actual = generate_mfa_url(user=user, session=session)
        assert user.mfa_secret is not None
        # activate MFA
        activate_mfa_request = ActivateMFARequest(
            user_otp=pyotp.TOTP(user.mfa_secret).now()
        )
        actual = activate_mfa(
            activate_mfa_request=activate_mfa_request, user=user, session=session
        )
        assert actual.is_mfa_active
        display(actual)

    # The setup session's context has exited; use a fresh session for the calls
    # under test instead of reusing the exited one.
    with get_session_with_context() as session:
        actual = send_sms_otp(
            username=test_username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

        reset_password_request = ResetPasswordRequest(
            username=test_username, new_password="random_password", otp="000000"
        )
        actual = test_require_otp_or_totp(
            reset_password_request=reset_password_request, session=session
        )
        display(actual)
        # The decorated function must have been reached
        assert actual == "Ok", actual

User(id=464, uuid=UUID('ed4b1876-cb7a-432b-a939-ad62a7a12d07'), username='qkdfggzzjh', first_name='unittest', last_name='user', email='qkdfggzzjh@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 54, 28), phone_number='910000000000', is_phone_number_verified=True, mfa_secret=**********************************, is_mfa_active=True)

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'Ok'

In [None]:
# Negative scenario: MFA-enabled user with a verified phone number submits a code
# that is rejected as an SMS OTP (verify_pin is mocked to report WRONG_PIN)
test_username = create_user_for_testing()
random_phone_number = "910000000000"
random_sms_pin_id = "my_random_pin_id"
message_template_name = "reset_password"
random_sms_otp = "111111"
with MonkeyPatch.context() as monkeypatch:
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    verify_pin_sample_response = {
        "pinId": random_sms_pin_id,
        "msisdn": random_phone_number,
        "verified": False,
        "attemptsRemaining": 0,
        "pinError": "WRONG_PIN",
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    # Setup: give the user a verified phone number and activate MFA
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        user.phone_number = random_phone_number
        user.is_phone_number_verified = True
        session.add(user)
        session.commit()
        session.refresh(user)

        actual = generate_mfa_url(user=user, session=session)
        assert user.mfa_secret is not None
        # activate MFA
        activate_mfa_request = ActivateMFARequest(
            user_otp=pyotp.TOTP(user.mfa_secret).now()
        )
        actual = activate_mfa(
            activate_mfa_request=activate_mfa_request, user=user, session=session
        )
        assert actual.is_mfa_active
        display(actual)

    # The setup session's context has exited; use a fresh session for the calls
    # under test instead of reusing the exited one.
    with get_session_with_context() as session:
        actual = send_sms_otp(
            username=test_username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

        reset_password_request = ResetPasswordRequest(
            username=test_username, new_password="random_password", otp="000000"
        )
        with pytest.raises(HTTPException) as e:
            test_require_otp_or_totp(
                reset_password_request=reset_password_request, session=session
            )
        display(e.value.detail)
        # Pin the failure mode instead of only displaying it
        assert "Something went wrong" in str(e.value.detail), str(e.value.detail)

User(id=465, uuid=UUID('4c9b0913-f86a-4407-9b1a-8b173ff3a2cb'), username='igrhssbdmg', first_name='unittest', last_name='user', email='igrhssbdmg@email.com', subscription_type=<SubscriptionType.test: 'test'>, super_user=False, disabled=False, created=datetime.datetime(2023, 3, 10, 7, 54, 35), phone_number='910000000000', is_phone_number_verified=True, mfa_secret=**********************************, is_mfa_active=True)

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'Something went wrong. The username or OTP you entered is incorrect. Please try again or contact your administrator.'

In [None]:
# | export


@user_router.post(
    "/reset_password",
    responses={
        401: {"model": HTTPError, "description": ERRORS["INCORRECT_USERNAME_OR_OTP"]},
    },
)
@require_otp_or_totp(message_template_name="reset_password")
def reset_password(
    reset_password_request: ResetPasswordRequest,
    session: Session = Depends(get_session),
) -> str:
    """Reset password for the user

    The totp/otp carried in the request is validated by the require_otp_or_totp
    decorator before this function body runs.

    Args:
        reset_password_request: Username, new password and totp/otp for the reset
        session: Database session, injected by the dependency

    Returns:
        The password reset confirmation message

    Raises:
        HTTPException: With status code 401 if the username is unknown or the totp/otp is rejected
    """
    username = reset_password_request.username
    new_password = reset_password_request.new_password

    user = get_user(username)
    if user is None:
        # get_user can return None (the decorator guards against the same case);
        # fail with the documented 401 instead of passing None to session.merge.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERRORS["INCORRECT_USERNAME_OR_OTP"],
        )
    # Attach the user object to this request's session before updating it
    user = session.merge(user)

    with commit_or_rollback(session):
        user.password = get_password_hash(new_password)  # type: ignore
        session.add(user)

    return PASSWORD_RESET_MSG

In [None]:
# Tests for reset_password:
# Negative Scenario: passing invalid username

reset_password_request = ResetPasswordRequest(
    username="this_is_an_invalid_username", new_password="random_password", otp="000000"
)

# Open a dedicated session rather than relying on a `session` variable left
# behind by an earlier cell
with get_session_with_context() as session:
    with pytest.raises(HTTPException) as e:
        reset_password(reset_password_request=reset_password_request, session=session)


assert "Something went wrong" in str(e.value.detail), str(e.value.detail)
display(e.value.detail)

'Something went wrong. The username or OTP you entered is incorrect. Please try again or contact your administrator.'

In [None]:
# Tests for reset_password:
# Negative Scenario: Non-MFA user with no phone number registered trying to reset_password
test_username = create_user_for_testing()
# NOTE(review): the reset_password endpoint validates against the "reset_password"
# template; confirm that "password_reset" is the intended value here.
message_template_name = "password_reset"

# Open a dedicated session rather than relying on a `session` variable left
# behind by an earlier cell
with get_session_with_context() as session:
    actual = send_sms_otp(
        username=test_username,
        message_template_name=message_template_name,
        session=session,
    )
    display(actual)

    reset_password_request = ResetPasswordRequest(
        username=test_username, new_password="random_password", otp="000000"
    )

    with pytest.raises(HTTPException) as e:
        reset_password(reset_password_request=reset_password_request, session=session)


assert "Something went wrong" in str(e.value.detail), str(e.value.detail)
display(f"\n\n{e.value.detail}")

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'\n\nSomething went wrong. The username or OTP you entered is incorrect. Please try again or contact your administrator.'

In [None]:
# Tests for reset_password:
# Negative Scenario: Non-MFA user whose phone number is registered but not validated trying to reset_password

random_phone_number = "910000000000"
# NOTE(review): the reset_password endpoint validates against the "reset_password"
# template; confirm that "password_reset" is the intended value here.
message_template_name = "password_reset"
with get_session_with_context() as session:
    # NOTE: test_username is not defined in this cell; it is reused from the previous cell
    user = session.exec(select(User).where(User.username == test_username)).one()
    # Register the phone number only; is_phone_number_verified is left untouched
    user.phone_number = random_phone_number
    session.add(user)
    session.commit()

    actual = send_sms_otp(
        username=test_username,
        message_template_name=message_template_name,
        session=session,
    )
    display(actual)

    reset_password_request = ResetPasswordRequest(
        username=test_username, new_password="random_password", otp="000000"
    )

    with pytest.raises(HTTPException) as e:
        reset_password(reset_password_request=reset_password_request, session=session)


assert "Something went wrong" in str(e.value.detail), str(e.value.detail)
display(f"\n\n{e.value.detail}")

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'\n\nSomething went wrong. The username or OTP you entered is incorrect. Please try again or contact your administrator.'

In [None]:
# Tests for reset_password:
# Negative Scenario: Non-MFA user with phone number validated passing invalid OTP

test_username = create_user_for_testing()
random_phone_number = "910000000000"
random_sms_pin_id = "my_random_pin_id"
# Must match the template name the reset_password endpoint validates against;
# otherwise the request is rejected before the mocked WRONG_PIN response is
# ever consulted and the invalid-OTP path is not exercised.
message_template_name = "reset_password"
with MonkeyPatch.context() as monkeypatch:
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    verify_pin_sample_response = {
        "pinId": random_sms_pin_id,
        "msisdn": random_phone_number,
        "verified": False,
        "attemptsRemaining": 0,
        "pinError": "WRONG_PIN",
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    # Setup: give the user a verified phone number
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        user.phone_number = random_phone_number
        user.is_phone_number_verified = True
        session.add(user)
        session.commit()
        session.refresh(user)

    # The setup session's context has exited; use a fresh session for the calls
    # under test instead of reusing the exited one.
    with get_session_with_context() as session:
        actual = send_sms_otp(
            username=test_username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

        reset_password_request = ResetPasswordRequest(
            username=test_username, new_password="random_password", otp="000000"
        )

        with pytest.raises(HTTPException) as e:
            reset_password(
                reset_password_request=reset_password_request, session=session
            )

assert "Something went wrong" in str(e.value.detail), str(e.value.detail)
display(f"\n\n{e.value.detail}")

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'\n\nSomething went wrong. The username or OTP you entered is incorrect. Please try again or contact your administrator.'

In [None]:
# Tests for reset_password:
# Negative Scenario: Non-MFA user with phone number validated passing invalid message_template_name

test_username = create_user_for_testing()
random_phone_number = "910000000000"
random_sms_pin_id = "my_random_pin_id"
new_password = "random_password"
message_template_name = "invalid_message_template_name"
with MonkeyPatch.context() as monkeypatch:
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    verify_pin_sample_response = {
        "pinId": random_sms_pin_id,
        "msisdn": random_phone_number,
        "verified": True,
        "attemptsRemaining": 0,
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    # Setup: give the user a verified phone number
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        user.phone_number = random_phone_number
        user.is_phone_number_verified = True
        session.add(user)
        session.commit()
        session.refresh(user)

    # The setup session's context has exited; use a fresh session for the calls
    # under test instead of reusing the exited one.
    with get_session_with_context() as session:
        actual = send_sms_otp(
            username=test_username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

        reset_password_request = ResetPasswordRequest(
            username=test_username, new_password=new_password, otp="000000"
        )
        with pytest.raises(HTTPException) as e:
            reset_password(
                reset_password_request=reset_password_request, session=session
            )

    assert "Something went wrong" in str(e.value.detail), str(e.value.detail)
    display(f"\n\n{e.value.detail}")

    # No SMS row must exist for the user when the template name is invalid
    with get_session_with_context() as session:
        with pytest.raises(NoResultFound) as e:
            sms = session.exec(select(SMS).where(SMS.user == user)).one()

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'\n\nSomething went wrong. The username or OTP you entered is incorrect. Please try again or contact your administrator.'

In [None]:
# Tests for reset_password:
# Positive Scenario: Non-MFA user with phone number validated passing valid OTP

test_username = create_user_for_testing()
random_phone_number = "910000000000"
random_sms_pin_id = "my_random_pin_id"
new_password = "random_password"
message_template_name = "reset_password"
with MonkeyPatch.context() as monkeypatch:
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    verify_pin_sample_response = {
        "pinId": random_sms_pin_id,
        "msisdn": random_phone_number,
        "verified": True,
        "attemptsRemaining": 0,
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    # Setup: give the user a verified phone number
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        user.phone_number = random_phone_number
        user.is_phone_number_verified = True
        session.add(user)
        session.commit()
        session.refresh(user)

    # The setup session's context has exited; use a fresh session for the calls
    # under test instead of reusing the exited one.
    with get_session_with_context() as session:
        actual = send_sms_otp(
            username=test_username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

        reset_password_request = ResetPasswordRequest(
            username=test_username, new_password=new_password, otp="000000"
        )

        actual = reset_password(
            reset_password_request=reset_password_request, session=session
        )
        display(f"\n\n{actual}")
        assert "Password reset successful" in actual

    # Verify the persisted state from an independent session
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        display(f"{verify_password(new_password, user.password)=}")
        assert verify_password(new_password, user.password)

        sms = session.exec(select(SMS).where(SMS.user == user)).one()

        display(sms)
        assert sms.user_id == user.id

        # After a successful validation no SMSProtocol row is expected for this SMS
        with pytest.raises(NoResultFound) as e:
            sms_protocol = session.exec(
                select(SMSProtocol).where(SMSProtocol.sms_id == sms.id)
            ).one()
        display(e.value)

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'\n\nPassword reset successful'

'verify_password(new_password, user.password)=True'

SMS(message_id='74444CB4E552B0BF5E57DC305B49DD64', application_id='0636B234BE48C57A95BB07AA93A5483E', id=40, user_id=469)

sqlalchemy.exc.NoResultFound('No row was found when one was required')

In [None]:
# Tests for reset_password:
# Negative Scenario: MFA user without phone number passing invalid otp
with create_mfa_enabled_user() as user_and_session:
    # The context manager yields a (user, session) pair
    user, session = user_and_session

    reset_password_request = ResetPasswordRequest(
        username=user.username, new_password="random_password", otp="000000"
    )

    # The request must be rejected with an HTTPException
    with pytest.raises(HTTPException) as e:
        reset_password(reset_password_request=reset_password_request, session=session)


# The captured exception remains available after the with blocks have exited
assert "Something went wrong" in str(e.value.detail), str(e.value.detail)
display(f"\n\n{e.value.detail}")

'\n\nSomething went wrong. The username or OTP you entered is incorrect. Please try again or contact your administrator.'

In [None]:
# Tests for reset_password:
# Positive Scenario: MFA user without phone number passing valid totp

new_password = "random_password"
message_template_name = "reset_password"
with create_mfa_enabled_user() as user_and_session:
    # The context manager yields a (user, session) pair
    user, session = user_and_session

    actual = send_sms_otp(
        username=user.username,
        message_template_name=message_template_name,
        session=session,
    )
    display(actual)

    # Use the TOTP currently generated from the user's MFA secret
    reset_password_request = ResetPasswordRequest(
        username=user.username,
        new_password=new_password,
        otp=pyotp.TOTP(user.mfa_secret).now(),
    )
    actual = reset_password(
        reset_password_request=reset_password_request, session=session
    )
    display(f"\n\n{actual}")
    assert "Password reset successful" in actual

# Verify the persisted state from an independent session
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == user.username)).one()

    display(f"{verify_password(new_password, user.password)=}")
    assert verify_password(new_password, user.password)

    # No SMS row is expected to exist for this user
    with pytest.raises(NoResultFound) as e:
        session.exec(select(SMS).where(SMS.user == user)).one()
    display(e.value)

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'\n\nPassword reset successful'

'verify_password(new_password, user.password)=True'

sqlalchemy.exc.NoResultFound('No row was found when one was required')

In [None]:
# Tests for reset_password:
# Positive Scenario: MFA user with phone number calling reset_password directly with a valid totp


random_phone_number = "910000000000"
new_password = "random_password"
with create_mfa_enabled_user() as user_and_session:
    # The context manager yields a (user, session) pair
    user, session = user_and_session

    # Give the MFA user a verified phone number
    user.phone_number = random_phone_number
    user.is_phone_number_verified = True
    session.add(user)
    session.commit()
    session.refresh(user)
    display(f"{user.phone_number=}")
    assert user.phone_number
    assert user.is_phone_number_verified

    # send_sms_otp is intentionally not called in this scenario
    reset_password_request = ResetPasswordRequest(
        username=user.username,
        new_password=new_password,
        otp=pyotp.TOTP(user.mfa_secret).now(),
    )
    actual = reset_password(
        reset_password_request=reset_password_request, session=session
    )
    display(f"\n\n{actual}")
    assert "Password reset successful" in actual

    # NOTE: this rebinds `session` and `user`, shadowing the pair yielded above
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == user.username)).one()
        display(f"{verify_password(new_password, user.password)=}")
        assert verify_password(new_password, user.password)
        # No SMS row is expected to exist for this user
        with pytest.raises(NoResultFound) as e:
            sms = session.exec(select(SMS).where(SMS.user == user)).one()
        display(e.value)

"user.phone_number='910000000000'"

'\n\nPassword reset successful'

'verify_password(new_password, user.password)=True'

sqlalchemy.exc.NoResultFound('No row was found when one was required')

In [None]:
# Tests for reset_password:
# Positive Scenario: MFA user with phone number passing valid totp (First calls send_sms_otp, then calls reset_password with valid totp)

random_sms_pin_id = "00000000000000"
random_phone_number = "910000000000"
new_password = "random_password"
message_template_name = "reset_password"
with MonkeyPatch.context() as monkeypatch:
    # Only send_sms is mocked in this cell; verify_pin is not patched
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    with create_mfa_enabled_user() as user_and_session:
        # The context manager yields a (user, session) pair
        user, session = user_and_session

        # Give the MFA user a verified phone number
        user.phone_number = random_phone_number
        user.is_phone_number_verified = True
        session.add(user)
        session.commit()
        session.refresh(user)
        display(f"{user.phone_number=}")
        assert user.phone_number
        assert user.is_phone_number_verified

        actual = send_sms_otp(
            username=user.username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

        # Reset using the current TOTP rather than the SMS OTP requested above
        reset_password_request = ResetPasswordRequest(
            username=user.username,
            new_password=new_password,
            otp=pyotp.TOTP(user.mfa_secret).now(),
        )
        actual = reset_password(
            reset_password_request=reset_password_request, session=session
        )
        display(f"\n\n{actual}")
        assert "Password reset successful" in actual

    # Verify the persisted state from an independent session
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == user.username)).one()
        display(f"{verify_password(new_password, user.password)=}")
        assert verify_password(new_password, user.password)
        # Exactly one SMS row is expected for the user (.one() raises otherwise)
        sms = session.exec(select(SMS).where(SMS.user == user)).one()
        display(sms)

"user.phone_number='910000000000'"

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'\n\nPassword reset successful'

'verify_password(new_password, user.password)=True'

SMS(message_id='74444CB4E552B0BF5E57DC305B49DD64', application_id='0636B234BE48C57A95BB07AA93A5483E', id=41, user_id=473)

In [None]:
# Tests for reset_password:
# Positive Scenario: MFA user with phone number passing valid sms otp (First calls send_sms_otp, then calls reset_password with valid sms otp)

random_sms_pin_id = "00000000000000"
random_phone_number = "910000000000"
new_password = "random_password"
# Not a valid TOTP for the user; it is accepted in this scenario only because
# verify_pin is patched below to report the SMS pin as verified.
invalid_totp = "000000"
message_template_name = "reset_password"
with MonkeyPatch.context() as monkeypatch:
    # Canned response returned by the patched SMS sender.
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    # Canned response returned by the patched pin verification: pin accepted.
    verify_pin_sample_response = {
        "pinId": random_sms_pin_id,
        "msisdn": random_phone_number,
        "verified": True,
        "attemptsRemaining": 0,
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    with create_mfa_enabled_user() as user_and_session:
        user, session = user_and_session

        # Give the user a verified phone number before requesting the SMS OTP.
        user.phone_number = random_phone_number
        user.is_phone_number_verified = True
        session.add(user)
        session.commit()
        session.refresh(user)
        display(f"{user.phone_number=}")
        assert user.phone_number
        assert user.is_phone_number_verified

        actual = send_sms_otp(
            username=user.username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

        reset_password_request = ResetPasswordRequest(
            username=user.username, new_password=new_password, otp=invalid_totp
        )
        actual = reset_password(
            reset_password_request=reset_password_request, session=session
        )
        display(f"\n\n{actual}")
        assert "Password reset successful" in actual

        # Re-read the user in a fresh session and confirm the new password was stored.
        with get_session_with_context() as session:
            user = session.exec(
                select(User).where(User.username == user.username)
            ).one()
            display(f"{verify_password(new_password, user.password)=}")
            assert verify_password(new_password, user.password)
            sms = session.exec(select(SMS).where(SMS.user == user)).one()
            display(sms)
            assert sms.user_id == user.id

            # No SMSProtocol row is expected to remain for this SMS record.
            with pytest.raises(NoResultFound) as e:
                sms_protocol = session.exec(
                    select(SMSProtocol).where(SMSProtocol.sms_id == sms.id)
                ).one()
            display(e.value)

"user.phone_number='910000000000'"

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'\n\nPassword reset successful'

'verify_password(new_password, user.password)=True'

SMS(message_id='74444CB4E552B0BF5E57DC305B49DD64', application_id='0636B234BE48C57A95BB07AA93A5483E', id=42, user_id=474)

sqlalchemy.exc.NoResultFound('No row was found when one was required')

In [None]:
# Tests for reset_password:
# Negative Scenario: MFA user with phone number passing invalid sms otp (First calls send_sms_otp, then calls reset_password with invalid sms otp)

random_sms_pin_id = "00000000000000"
random_phone_number = "910000000000"
new_password = "random_password"
# Not a valid TOTP for the user, and the patched verify_pin below rejects it as an SMS pin too.
invalid_totp = "000000"
message_template_name = "reset_password"
with MonkeyPatch.context() as monkeypatch:
    # Canned response returned by the patched SMS sender.
    send_sms_sample_response = {
        "pinId": random_sms_pin_id,
        "to": random_phone_number,
        "ncStatus": "NC_NOT_CONFIGURED",
        "smsStatus": "MESSAGE_SENT",
    }
    monkeypatch.setattr(
        "airt_service.sms_utils.send_sms", lambda x, y, z: send_sms_sample_response
    )

    # Canned response returned by the patched pin verification: pin rejected.
    verify_pin_sample_response = {
        "pinId": random_sms_pin_id,
        "msisdn": random_phone_number,
        "verified": False,
        "attemptsRemaining": 0,
        "pinError": "WRONG_PIN",
    }

    monkeypatch.setattr(
        "airt_service.sms_utils.verify_pin", lambda x, y: verify_pin_sample_response
    )

    with create_mfa_enabled_user() as user_and_session:
        user, session = user_and_session

        # Give the user a verified phone number before requesting the SMS OTP.
        user.phone_number = random_phone_number
        user.is_phone_number_verified = True
        session.add(user)
        session.commit()
        session.refresh(user)
        display(f"{user.phone_number=}")
        assert user.phone_number
        assert user.is_phone_number_verified

        actual = send_sms_otp(
            username=user.username,
            message_template_name=message_template_name,
            session=session,
        )
        display(actual)

        reset_password_request = ResetPasswordRequest(
            username=user.username, new_password=new_password, otp=invalid_totp
        )
        # The reset must be refused because neither OTP check succeeds.
        with pytest.raises(HTTPException) as e:
            reset_password(
                reset_password_request=reset_password_request, session=session
            )

        assert "Something went wrong" in str(e.value.detail), str(e.value.detail)
        display(f"\n\n{e.value.detail}")

        # Re-read the user in a fresh session and confirm the password was NOT changed.
        with get_session_with_context() as session:
            user = session.exec(
                select(User).where(User.username == user.username)
            ).one()
            display(f"{verify_password(new_password, user.password)=}")
            assert not verify_password(new_password, user.password)

"user.phone_number='910000000000'"

'If you have already registered and verified your phone number, you will receive the OTP by SMS. If you did not receive the OTP, please contact your administrator.'

'\n\nSomething went wrong. The username or OTP you entered is incorrect. Please try again or contact your administrator.'

'verify_password(new_password, user.password)=False'

In [None]:
# | export


class UserCleanupRequest(BaseModel):
    """Request object to cleanup user

    Args:
        username: username
        otp: Dynamically generated six-digit verification code from the authenticator app
    """

    # Username of the account that should be cleaned up.
    username: str
    # Optional one-time password; None when the request carries no code.
    otp: Optional[str] = None

In [None]:
# | export


@user_router.post(
    "/cleanup",
    response_model=UserRead,
    responses={
        400: {"model": HTTPError, "description": ERRORS["CANT_CLEANUP_SELF"]},
        401: {"model": HTTPError, "description": ERRORS["NOT_ENOUGH_PERMISSION"]},
    },
)
@require_otp_if_mfa_enabled
@ensure_super_user
def cleanup(
    user_to_cleanup: UserCleanupRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> User:
    """
    Cleanup user
    """
    # The docstring above is left short on purpose: FastAPI publishes a route
    # function's docstring as the operation description.
    #
    # Parameters:
    #   user_to_cleanup: request naming the account to remove (plus optional OTP).
    #   user: the caller, resolved through get_current_active_user.
    #   session: database session used for the lookup and the cleanup.
    #
    # Returns:
    #   The calling user (merged into `session`), not the removed account.
    #
    # Raises:
    #   HTTPException(400) with ERRORS["CANT_CLEANUP_SELF"] when the caller
    #   names their own account, with ERRORS["INCORRECT_USERNAME"] when no
    #   such user exists, or with the underlying error message when the
    #   lookup fails for any other reason.
    #
    # OTP and super-user checks are delegated to the decorators.
    user = session.merge(user)

    if user.username == user_to_cleanup.username:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["CANT_CLEANUP_SELF"],
        )

    try:
        user_to_cleanup_obj = session.exec(
            select(User).where(User.username == user_to_cleanup.username)
        ).one()
    except NoResultFound as e:
        # Chain explicitly so the original lookup failure stays attached as the cause.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INCORRECT_USERNAME"],
        ) from e
    except Exception as e:
        logger.exception(e)
        # Prefer the exception's own `_message()` when it provides one.
        error_message = (
            e._message() if callable(getattr(e, "_message", None)) else str(e)  # type: ignore
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=error_message,
        ) from e

    cleanup_user(user_to_cleanup_obj, session)

    return user

In [None]:
# cleanup called by an MFA-enabled super user: the OTP checks must reject the request
with create_mfa_enabled_super_user() as user_and_session:
    user, session = user_and_session

    # Negative scenario: the request carries no OTP at all
    with pytest.raises(HTTPException) as e:
        cleanup(
            UserCleanupRequest(username=user.username),
            user=user,
            session=session,
        )
    assert "OTP is required" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Negative scenario: the request carries an OTP that does not validate
    with pytest.raises(HTTPException) as e:
        cleanup(
            user_to_cleanup=UserCleanupRequest(username=user.username, otp="123123"),
            user=user,
            session=session,
        )
    assert "Invalid OTP" in str(e.value.detail), str(e.value.detail)
    display(e.value.detail)

    # Remember the username; a later cell reuses it
    test_user_username = user.username

'OTP is required. Please enter the OTP generated by the authenticator app or the one you requested via SMS.'

'Invalid OTP. Please try again.'

In [None]:
# Negative Scenario: user without MFA activated ("kumaran") passes an OTP while
# requesting cleanup of another account; the request must be rejected.
with get_session_with_context() as session:
    # Fetch the caller outside pytest.raises so the block only covers the call under test.
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()
    # UserCleanupRequest.otp is declared as Optional[str], so pass a string.
    random_otp = "111111"
    with pytest.raises(HTTPException) as e:
        cleanup(
            user_to_cleanup=UserCleanupRequest(
                username=test_user_username, otp=random_otp
            ),
            user=user_kumaran,
            session=session,
        )
    assert (
        str(e.value.detail)
        == "MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account."
    )
    display(e.value.detail)

'MFA is not activated for the account. Please pass the OTP only after activating the MFA for your account.'

In [None]:
# Negative scenarios for cleanup: cleaning up one's own account, and calling
# the route as a user created without super-user rights.
non_admin_username = create_user_for_testing(subscription_type="small")
with get_session_with_context() as session:
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()
    non_admin_user = session.exec(
        select(User).where(User.username == non_admin_username)
    ).one()

    # The caller names their own account: cleanup must refuse with a 400.
    with pytest.raises(HTTPException) as e:
        cleanup(
            UserCleanupRequest(username=user_kumaran.username),
            user=user_kumaran,
            session=session,
        )
    assert e.value.status_code == status.HTTP_400_BAD_REQUEST, e.value.status_code
    assert e.value.detail == ERRORS["CANT_CLEANUP_SELF"], str(e.value.detail)
    display(e)

    # The caller is not a super user: the route must answer with a 401.
    with pytest.raises(HTTPException) as e:
        cleanup(
            UserCleanupRequest(username=user_kumaran.username),
            user=non_admin_user,
            session=session,
        )
    assert e.value.status_code == status.HTTP_401_UNAUTHORIZED, e.value.status_code
    display(e)

<ExceptionInfo HTTPException(status_code=400, detail='Users cannot delete their own account.') tblen=4>

<ExceptionInfo HTTPException(status_code=401, detail='You do not have sufficient permission to access this route. Please contact your administrator for help.') tblen=3>

In [None]:
# Build a throwaway user that owns an API key, a datablob, a datasource,
# a model and a prediction.
to_delete_username = create_user_for_testing(subscription_type="small")
display(to_delete_username)

with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == to_delete_username)).one()

    # API key expiring one day from now.
    create_apikey(
        apikey_to_create=APIKeyCreate(expiry=datetime.utcnow() + timedelta(days=1)),
        user=user,
        session=session,
    )

    # Start a "from local" datablob upload; the response carries the presigned upload target.
    from_local_request = FromLocalRequest(
        path="tmp/test-folder/", tag="my_csv_datasource_tag"
    )
    from_local_response = from_local_start_route(
        from_local_request=from_local_request,
        user=user,
        session=session,
    )

    # Pull the sample parquet data, convert it to CSV locally and upload that
    # CSV through the presigned URL. push_on_exit=False: nothing is written
    # back to the sample location.
    with RemotePath.from_url(
        remote_url=f"s3://test-airt-service/account_312571_events",
        pull_on_enter=True,
        push_on_exit=False,
        exist_ok=True,
        parents=False,
        access_key=environ["AWS_ACCESS_KEY_ID"],
        secret_key=environ["AWS_SECRET_ACCESS_KEY"],
    ) as test_s3_path:
        df = pd.read_parquet(test_s3_path.as_path())
        display(df.head())
        df.to_csv(test_s3_path.as_path() / "file.csv", index=False)
        display(list(test_s3_path.as_path().glob("*")))
        #         !head -n 10 {test_s3_path.as_path()/"file.csv"}

        upload_to_s3_with_retry(
            test_s3_path.as_path() / "file.csv",
            from_local_response.presigned["url"],
            from_local_response.presigned["fields"],
        )

    # Look up the datablob created by the upload request and attach a datasource to it.
    datablob_id = (
        session.exec(select(DataBlob).where(DataBlob.uuid == from_local_response.uuid))
        .one()
        .id
    )
    datasource = DataSource(
        datablob_id=datablob_id,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    session.add(datasource)
    session.commit()

    # Process the uploaded CSV into the datasource.
    process_csv(
        datablob_id=datablob_id,
        datasource_id=datasource.id,
        deduplicate_data=True,
        index_column="PersonId",
        sort_by="OccurredTime",
        blocksize="256MB",
        kwargs_json=json.dumps(
            dict(
                usecols=[0, 1, 2, 3, 4],
                parse_dates=["OccurredTime"],
            )
        ),
    )

# Second session: reload the processed datasource, then create a model and a prediction.
# NOTE(review): `user` below is the object loaded in the previous session block.
with get_session_with_context() as session:
    datasource = session.exec(
        select(DataSource).where(DataSource.id == datasource.id)
    ).one()
    display(datasource)

    train_request = TrainRequest(
        data_uuid=datasource.uuid,
        client_column="AccountId",
        target_column="DefinitionId",
        target="load*",
        # 20 days, expressed in seconds.
        predict_after=timedelta(seconds=20 * 24 * 60 * 60),
    )

    model = train_model(train_request=train_request, user=user, session=session)
    display(model)
    # Call exec_cli train_model

    b = BackgroundTasks()
    # JOB_EXECUTOR is set to "fastapi" only for the duration of the predict call.
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        predicted = predict_model(
            model_uuid=model.uuid, user=user, session=session, background_tasks=b
        )
    display(predicted)
# Call exec_cli predict_model

'ofnsxhojge'

[INFO] botocore.credentials: Found credentials in environment variables.
[INFO] airt_service.data.datablob: DataBlob.from_local(): FromLocalResponse(uuid=UUID('dd1a7fc8-190c-459f-8a63-d95adea5fb62'), type='local', presigned={'url': 'https://kumaran-airt-service-eu-west-1.s3.amazonaws.com/', 'fields': {'key': '****************************************', 'x-amz-algorithm': 'AWS4-HMAC-SHA256', 'x-amz-credential': '********************/20230310/eu-west-1/s3/aws4_request', 'x-amz-date': '20230310T075535Z', 'policy': '************************************************************************************************************************************************************************************************************************************************************', 'x-amz-signature': '****************************'}})
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._creat

Unnamed: 0_level_0,AccountId,DefinitionId,OccurredTime,OccurredTimeTicks,PersonId
__null_dask_index__,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,312571,loadTests2,2019-12-31 21:30:02,1577836802678,2
1,312571,loadTests3,2020-01-03 23:53:22,1578104602678,2
2,312571,loadTests1,2020-01-07 02:16:42,1578372402678,2
3,312571,loadTests2,2020-01-10 04:40:02,1578640202678,2
4,312571,loadTests3,2020-01-13 07:03:22,1578908002678,2


[Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_efd8hb46/_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_efd8hb46/_common_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_efd8hb46/file.csv'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_efd8hb46/part.3.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_efd8hb46/part.0.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_efd8hb46/part.1.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_efd8hb46/part.4.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_efd8hb46/part.2.parquet')]

[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3test-airt-serviceaccount_312571_events_cached_efd8hb46
[INFO] airt_service.data.csv: process_csv(datablob_id=101, datasource_id=59): processing user uploaded csv file for datablob_id=101 and uploading parquet back to S3 for datasource_id=59
[INFO] airt_service.data.csv: process_csv(datablob_id=101, datasource_id=59): step 1/4: downloading user uploaded file from bucket s3://kumaran-airt-service-eu-west-1/478/datablob/101
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/478/datablob/101
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-1478datablob101_cached_lwrci1ig
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/478/datablob/101 locally in /tmp/s3kumaran-airt-service-eu-west-1478datablob101_cached_lwrci1ig
[INFO

DataSource(id=59, uuid=UUID('0eba591a-bdbd-42ca-85f8-c609c6dbbf23'), hash='1dd8ee7a0f96a48110dec6e25891d18d', total_steps=1, completed_steps=1, folder_size=6619982, no_of_rows=498961, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path='s3://kumaran-airt-service-eu-west-1/478/datasource/59', created=datetime.datetime(2023, 3, 10, 7, 55, 53), user_id=478, pulled_on=datetime.datetime(2023, 3, 10, 7, 55, 59), tags=[])

Model(client_column='AccountId', completed_steps=0, datasource_id=59, cloud_provider=<CloudProvider.aws: 'aws'>, error=None, user_id=478, target_column='DefinitionId', region='eu-west-1', target='load*', disabled=False, predict_after=datetime.timedelta(days=20), created=datetime.datetime(2023, 3, 10, 7, 56, 18), timestamp_column=None, id=27, total_steps=5, uuid=UUID('8f289f40-c254-4e4a-b591-4621a58527a7'), path=None)

[INFO] airt_service.batch_job: create_batch_job(): command='predict 29', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='predict 29', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service'

Prediction(total_steps=3, model_id=27, created=datetime.datetime(2023, 3, 10, 7, 56, 18), uuid=UUID('859e21a4-a4f6-4009-ba81-57472cc38eb3'), path=None, completed_steps=0, region='eu-west-1', id=29, disabled=False, datasource_id=59, cloud_provider=<CloudProvider.aws: 'aws'>, error=None)

In [None]:
# Positive scenario: a super user cleans up another user's account, followed by
# a negative scenario: repeating the cleanup for the now-missing user.
with get_session_with_context() as session:
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()

    cleanup(
        UserCleanupRequest(username=to_delete_username),
        user=user_kumaran,
        session=session,
    )

    # The account must no longer be present in the database.
    with pytest.raises(NoResultFound):
        session.exec(select(User).where(User.username == to_delete_username)).one()

    # A second cleanup of the same username must fail as an unknown username.
    with pytest.raises(HTTPException) as e:
        cleanup(
            UserCleanupRequest(username=to_delete_username),
            user=user_kumaran,
            session=session,
        )
    assert e.value.status_code == status.HTTP_400_BAD_REQUEST, e.value.status_code
    assert e.value.detail == ERRORS["INCORRECT_USERNAME"], str(e.value.detail)
    display(e)

[INFO] airt_service.cleanup: deleting predictions
[INFO] airt_service.cleanup: deleting models
[INFO] airt_service.cleanup: deleting datasources
[INFO] airt_service.cleanup: deleting datablobs
[INFO] airt_service.cleanup: deleting apikeys
[INFO] airt_service.cleanup: Deleting user files in s3://kumaran-airt-service-eu-west-1/478
[INFO] airt_service.cleanup: deleting user


<ExceptionInfo HTTPException(status_code=400, detail='Incorrect username. Please try again.') tblen=4>