Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Users | Separation of the "Token" class #408

Open
wants to merge 2 commits into
base: users
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions backend/users/backend/common/services/jwt/redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from config.settings import redis_config
from utils.db.redis_client import RedisAccessClient, RedisRefreshClient


class RedisAccessToken:
client = RedisAccessClient(
host=redis_config.REDIS_SHARED_HOST,
port=redis_config.REDIS_SHARED_PORT,
db=redis_config.REDIS_ACCESS_DB,
password=redis_config.REDIS_SHARED_PASSWORD,
)

def add(self, token: str, value: dict):
self.client.add_token(token=token, value=value)

def get_data(self, token: str) -> dict | None:
return self.client.is_token_exist(token)


class RedisBanRefreshToken:
client = RedisRefreshClient(
host=redis_config.REDIS_SHARED_HOST,
port=redis_config.REDIS_SHARED_PORT,
db=redis_config.REDIS_REFRESH_DB,
password=redis_config.REDIS_SHARED_PASSWORD,
)

def ban(self, token: str, value: dict = None):
self.client.add_token(token=token, value=value)

def get_data(self, token: str) -> dict | None:
return self.client.is_token_exist(token)
52 changes: 12 additions & 40 deletions backend/users/backend/common/services/jwt/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,15 @@
from base.tokens.token import BaseToken
from common.services.jwt.exceptions import PayloadError, TokenExpired
from common.services.jwt.mixins import JWTMixin
from common.services.jwt.redis import RedisAccessToken, RedisBanRefreshToken
from common.services.jwt.users_payload import PayloadFactory
from config.settings import redis_config
from django.conf import settings
from django.utils import timezone
from utils.db.redis_client import RedisAccessClient, RedisClient, RedisRefreshClient


class Token(BaseToken, JWTMixin):
redis_access_client = RedisAccessClient(
host=redis_config.REDIS_SHARED_HOST,
port=redis_config.REDIS_SHARED_PORT,
db=redis_config.REDIS_ACCESS_DB,
password=redis_config.REDIS_SHARED_PASSWORD,
)

redis_refresh_client = RedisRefreshClient(
host=redis_config.REDIS_SHARED_HOST,
port=redis_config.REDIS_SHARED_PORT,
db=redis_config.REDIS_REFRESH_DB,
password=redis_config.REDIS_SHARED_PASSWORD,
)
redis_access_client = RedisAccessToken()
redis_ban_refresh_client = RedisBanRefreshToken()

@staticmethod
def validate_user(user: BaseAbstractUser):
Expand Down Expand Up @@ -71,16 +59,14 @@ def generate_access_token(self, data: dict = None) -> str:
self.validate_payload_data(data)
default_payload = self.get_default_access_payload()
redis_payload = {
"token_type": "access",
**default_payload,
**data,
}
payload = {
"token_type": "access",
**default_payload,
}
access_token = self._encode(payload)
self.__add_access_to_redis(token=access_token, value=redis_payload)
self.redis_access_client.add(token=access_token, value=redis_payload)
return access_token

def generate_refresh_token(self, data: dict) -> str:
Expand All @@ -106,26 +92,24 @@ def generate_access_token_for_user(self, user: BaseAbstractUser) -> str:
user_payload = self.get_user_payload(user)
default_payload = self.get_default_access_payload()
redis_payload = {
"token_type": "access",
**default_payload,
**user_payload,
}
payload = {
"token_type": "access",
**default_payload,
}
access_token = self._encode(payload)
self.__add_access_to_redis(token=access_token, value=redis_payload)
self.redis_access_client.add(token=access_token, value=redis_payload)
return access_token

def generate_refresh_token_for_user(self, user: BaseAbstractUser) -> str:
self.validate_user(user)
user_payload = self.get_user_payload(user)
default_payload = self.get_default_refresh_payload()
payload = {
"token_type": "refresh",
"user_id": str(user.id),
"role": user._meta.app_label,
**default_payload,
**user_payload,
}
refresh_token = self._encode(payload)

Expand Down Expand Up @@ -161,20 +145,8 @@ def check_exp_left(self, token: str) -> int:
def check_signature(self, token: str) -> None:
self._decode(token)

@staticmethod
def __add_token_to_redis(redis_client: RedisClient, token: str, value: dict):
redis_client.add_token(token=token, value=value)

@classmethod
def __add_access_to_redis(cls, token: str, value: dict):
cls.__add_token_to_redis(redis_client=cls.redis_access_client, token=token, value=value)

@classmethod
def add_refresh_to_redis(cls, token: str, value: dict = None):
cls.__add_token_to_redis(redis_client=cls.redis_refresh_client, token=token, value=value)

def get_access_data(self, token: str) -> dict | None:
return self.redis_access_client.is_token_exist(token)

def get_refresh_data(self, token: str) -> dict | None:
return self.redis_refresh_client.is_token_exist(token)
# for backward compatibility
redis_refresh_client = redis_ban_refresh_client.client
add_refresh_to_redis = redis_ban_refresh_client.ban
get_access_data = redis_access_client.get_data
get_refresh_data = redis_ban_refresh_client.get_data