Skip to content

Commit

Permalink
fixup! Add trench
Browse files Browse the repository at this point in the history
  • Loading branch information
gagantrivedi committed May 26, 2024
1 parent 3f9cd3e commit ec9176e
Show file tree
Hide file tree
Showing 17 changed files with 157 additions and 245 deletions.
1 change: 1 addition & 0 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@
"DEFAULT_VALIDITY_PERIOD": 30,
"CONFIRM_BACKUP_CODES_REGENERATION_WITH_CODE": True,
"APPLICATION_ISSUER_NAME": "app.bullet-train.io",
"ENCRYPT_BACKUP_CODES": True,
"MFA_METHODS": {
"app": {
"VERBOSE_NAME": "TOTP App",
Expand Down
3 changes: 0 additions & 3 deletions api/custom_auth/mfa/backends/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ def dispatch_message(self):
}
return Response(data)

def validate_confirmation_code(self, code: str) -> bool:
return self.validate_code(code)

def validate_code(self, code: str) -> bool:
validity_period = settings.TRENCH_AUTH["MFA_METHODS"]["app"]["VALIDITY_PERIOD"]
return self._totp.verify(otp=code, valid_window=int(validity_period / 20))
8 changes: 1 addition & 7 deletions api/custom_auth/mfa/trench/command/activate_mfa_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from custom_auth.mfa.trench.command.replace_mfa_method_backup_codes import (
regenerate_backup_codes_for_mfa_method_command,
)
from custom_auth.mfa.trench.exceptions import MFAMethodDoesNotExistError
from custom_auth.mfa.trench.models import MFAMethod
from custom_auth.mfa.trench.utils import get_mfa_model

Expand All @@ -19,16 +18,11 @@ def __init__(
self._backup_codes_generator = backup_codes_generator

def execute(self, user_id: int, name: str, code: str) -> Set[str]:
rows_affected = self._mfa_model.objects.filter(
user_id=user_id, name=name
).update(
self._mfa_model.objects.filter(user_id=user_id, name=name).update(
is_active=True,
is_primary=not self._mfa_model.objects.primary_exists(user_id=user_id),
)

if rows_affected < 1:
raise MFAMethodDoesNotExistError()

backup_codes = regenerate_backup_codes_for_mfa_method_command(
user_id=user_id,
name=name,
Expand Down
10 changes: 1 addition & 9 deletions api/custom_auth/mfa/trench/command/deactivate_mfa_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

from django.db.transaction import atomic

from custom_auth.mfa.trench.exceptions import (
DeactivationOfPrimaryMFAMethodError,
MFANotEnabledError,
)
from custom_auth.mfa.trench.exceptions import MFANotEnabledError
from custom_auth.mfa.trench.models import MFAMethod
from custom_auth.mfa.trench.utils import get_mfa_model

Expand All @@ -17,11 +14,6 @@ def __init__(self, mfa_model: Type[MFAMethod]) -> None:
@atomic
def execute(self, mfa_method_name: str, user_id: int) -> None:
mfa = self._mfa_model.objects.get_by_name(user_id=user_id, name=mfa_method_name)
number_of_active_mfa_methods = self._mfa_model.objects.filter(
user_id=user_id, is_active=True
).count()
if mfa.is_primary and number_of_active_mfa_methods > 1:
raise DeactivationOfPrimaryMFAMethodError()
if not mfa.is_active:
raise MFANotEnabledError()

Expand Down
9 changes: 4 additions & 5 deletions api/custom_auth/mfa/trench/command/generate_backup_codes.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from typing import Callable, Set

from django.conf import settings
from django.utils.crypto import get_random_string

from custom_auth.mfa.trench.settings import trench_settings


class GenerateBackupCodesCommand:
def __init__(self, random_string_generator: Callable) -> None:
self._random_string_generator = random_string_generator

def execute(
self,
quantity: int = trench_settings.BACKUP_CODES_QUANTITY,
length: int = trench_settings.BACKUP_CODES_LENGTH,
allowed_chars: str = trench_settings.BACKUP_CODES_CHARACTERS,
quantity: int = settings.TRENCH_AUTH["BACKUP_CODES_QUANTITY"],
length: int = settings.TRENCH_AUTH["BACKUP_CODES_LENGTH"],
allowed_chars: str = settings.TRENCH_AUTH["BACKUP_CODES_CHARACTERS"],
) -> Set[str]:
"""
Generates random encrypted backup codes.
Expand Down
16 changes: 2 additions & 14 deletions api/custom_auth/mfa/trench/command/remove_backup_code.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from typing import Any, Set
from typing import Any, Optional, Set

from django.contrib.auth.hashers import check_password

from custom_auth.mfa.trench.exceptions import (
InvalidCodeError,
MFAMethodDoesNotExistError,
)
from custom_auth.mfa.trench.models import MFAMethod
from custom_auth.mfa.trench.settings import trench_settings


def remove_backup_code_command(user_id: Any, method_name: str, code: str) -> None:
Expand All @@ -16,8 +11,6 @@ def remove_backup_code_command(user_id: Any, method_name: str, code: str) -> Non
.values_list("_backup_codes", flat=True)
.first()
)
if serialized_codes is None:
raise MFAMethodDoesNotExistError()
codes = MFAMethod._BACKUP_CODES_DELIMITER.join(
_remove_code_from_set(
backup_codes=set(serialized_codes.split(MFAMethod._BACKUP_CODES_DELIMITER)),
Expand All @@ -29,13 +22,8 @@ def remove_backup_code_command(user_id: Any, method_name: str, code: str) -> Non
)


def _remove_code_from_set(backup_codes: Set[str], code: str) -> Set[str]:
settings = trench_settings
if not settings.ENCRYPT_BACKUP_CODES:
backup_codes.remove(code)
return backup_codes
def _remove_code_from_set(backup_codes: Set[str], code: str) -> Optional[Set[str]]:
for backup_code in backup_codes:
if check_password(code, backup_code):
backup_codes.remove(backup_code)
return backup_codes
raise InvalidCodeError()
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from custom_auth.mfa.trench.command.generate_backup_codes import (
generate_backup_codes_command,
)
from custom_auth.mfa.trench.exceptions import MFAMethodDoesNotExistError
from custom_auth.mfa.trench.models import MFAMethod
from custom_auth.mfa.trench.settings import trench_settings
from custom_auth.mfa.trench.utils import get_mfa_model


Expand All @@ -26,25 +24,19 @@ def __init__(

def execute(self, user_id: int, name: str) -> Set[str]:
backup_codes = self._codes_generator()
rows_affected = self._mfa_model.objects.filter(
user_id=user_id, name=name
).update(
self._mfa_model.objects.filter(user_id=user_id, name=name).update(
_backup_codes=MFAMethod._BACKUP_CODES_DELIMITER.join(
[self._code_hasher(backup_code) for backup_code in backup_codes]
if self._requires_encryption
else backup_codes
),
)

if rows_affected < 1:
raise MFAMethodDoesNotExistError()

return backup_codes


regenerate_backup_codes_for_mfa_method_command = (
RegenerateBackupCodesForMFAMethodCommand(
requires_encryption=trench_settings.ENCRYPT_BACKUP_CODES,
requires_encryption=True,
mfa_model=get_mfa_model(),
code_hasher=make_password,
codes_generator=generate_backup_codes_command,
Expand Down
5 changes: 0 additions & 5 deletions api/custom_auth/mfa/trench/command/validate_backup_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,8 @@

from django.contrib.auth.hashers import check_password

from custom_auth.mfa.trench.settings import trench_settings


def validate_backup_code_command(value: str, backup_codes: Iterable) -> Optional[str]:
settings = trench_settings
if not settings.ENCRYPT_BACKUP_CODES:
return value if value in backup_codes else None
for backup_code in backup_codes:
if check_password(value, backup_code):
return backup_code
Expand Down
21 changes: 0 additions & 21 deletions api/custom_auth/mfa/trench/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
from django.core.exceptions import ImproperlyConfigured
from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import ValidationError


class MethodHandlerMissingError(ImproperlyConfigured):
def __init__(self, method_name: str) -> None:
super().__init__(f"Missing handler in {method_name} configuration.")


class MFAValidationError(ValidationError):
def __str__(self) -> str:
return ", ".join(detail for detail in self.detail)
Expand All @@ -21,11 +15,6 @@ def __init__(self) -> None:
)


class OTPCodeMissingError(MFAValidationError):
def __init__(self) -> None:
super().__init__(detail=_("OTP code not provided."), code="otp_code_missing")


class MFAMethodDoesNotExistError(MFAValidationError):
def __init__(self) -> None:
super().__init__(
Expand All @@ -42,16 +31,6 @@ def __init__(self) -> None:
)


class DeactivationOfPrimaryMFAMethodError(MFAValidationError):
def __init__(self) -> None:
super().__init__(
detail=_(
"Deactivation of MFA method that is set as primary is not allowed."
),
code="deactivation_of_primary",
)


class MFANotEnabledError(MFAValidationError):
def __init__(self) -> None:
super().__init__(detail=_("2FA is not enabled."), code="not_enabled")
Expand Down
7 changes: 0 additions & 7 deletions api/custom_auth/mfa/trench/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,6 @@ class Meta:

objects = MFAUserMethodManager()

def __str__(self) -> str:
return f"{self.name} (User id: {self.user_id})"

@property
def backup_codes(self) -> Iterable[str]:
return self._backup_codes.split(self._BACKUP_CODES_DELIMITER)

@backup_codes.setter
def backup_codes(self, codes: Iterable) -> None:
self._backup_codes = self._BACKUP_CODES_DELIMITER.join(codes)
24 changes: 1 addition & 23 deletions api/custom_auth/mfa/trench/responses.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from rest_framework.response import Response
from rest_framework.status import (
HTTP_200_OK,
HTTP_400_BAD_REQUEST,
HTTP_422_UNPROCESSABLE_ENTITY,
)
from rest_framework.status import HTTP_400_BAD_REQUEST

from custom_auth.mfa.trench.exceptions import MFAValidationError

Expand All @@ -12,24 +8,6 @@ class DispatchResponse(Response):
_FIELD_DETAILS = "details"


class SuccessfulDispatchResponse(DispatchResponse):
def __init__(
self, details: str, status: str = HTTP_200_OK, *args, **kwargs
) -> None:
super().__init__(
data={self._FIELD_DETAILS: details}, status=status, *args, **kwargs
)


class FailedDispatchResponse(DispatchResponse):
def __init__(
self, details: str, status: str = HTTP_422_UNPROCESSABLE_ENTITY, *args, **kwargs
) -> None:
super().__init__(
data={self._FIELD_DETAILS: details}, status=status, *args, **kwargs
)


class ErrorResponse(Response):
_FIELD_ERROR = "error"

Expand Down
51 changes: 2 additions & 49 deletions api/custom_auth/mfa/trench/serializers.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,40 @@
from abc import abstractmethod

from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractUser
from rest_framework.fields import CharField
from rest_framework.serializers import ModelSerializer, Serializer

from custom_auth.mfa.trench.command.remove_backup_code import (
remove_backup_code_command,
)
from custom_auth.mfa.trench.command.validate_backup_code import (
validate_backup_code_command,
)
from custom_auth.mfa.trench.exceptions import (
CodeInvalidOrExpiredError,
MFAMethodAlreadyActiveError,
MFANotEnabledError,
OTPCodeMissingError,
)
from custom_auth.mfa.trench.models import MFAMethod
from custom_auth.mfa.trench.settings import trench_settings
from custom_auth.mfa.trench.utils import get_mfa_handler, get_mfa_model

User: AbstractUser = get_user_model()


class ProtectedActionValidator(Serializer):
class MFAMethodActivationConfirmationValidator(Serializer):
code = CharField()

@staticmethod
def _get_validation_method_name() -> str:
return "validate_code"

@staticmethod
@abstractmethod
def _validate_mfa_method(mfa: MFAMethod) -> None:
raise NotImplementedError

def __init__(self, mfa_method_name: str, user: User, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._user = user
self._mfa_method_name = mfa_method_name

def validate_code(self, value: str) -> str:
if not value:
raise OTPCodeMissingError()
mfa_model = get_mfa_model()
mfa = mfa_model.objects.get_by_name(
user_id=self._user.id, name=self._mfa_method_name
)
self._validate_mfa_method(mfa)

validated_backup_code = validate_backup_code_command(
value=value, backup_codes=mfa.backup_codes
)

handler = get_mfa_handler(mfa)
validation_method = getattr(handler, self._get_validation_method_name())
if validation_method(value):
return value

if validated_backup_code:
remove_backup_code_command(
user_id=mfa.user_id, method_name=mfa.name, code=value
)
if handler.validate_code(value):
return value

raise CodeInvalidOrExpiredError()


class MFAMethodDeactivationValidator(ProtectedActionValidator):
code = CharField(required=trench_settings.CONFIRM_DISABLE_WITH_CODE)

@staticmethod
def _validate_mfa_method(mfa: MFAMethod) -> None:
if not mfa.is_active:
raise MFANotEnabledError()


class MFAMethodActivationConfirmationValidator(ProtectedActionValidator):
@staticmethod
def _get_validation_method_name() -> str:
return "validate_confirmation_code"

@staticmethod
def _validate_mfa_method(mfa: MFAMethod) -> None:
if mfa.is_active:
Expand Down
Loading

0 comments on commit ec9176e

Please sign in to comment.