From acf222498aedc9b093b2e66586d25e85b6cd0b94 Mon Sep 17 00:00:00 2001 From: Sebastian Allard Date: Mon, 11 Mar 2024 09:50:59 +0100 Subject: [PATCH 1/2] Add /auth endpoint (#402)(minor) Add endpoint for exchanging authorization code. --- Dockerfile | 3 + ...8_febdb4e78bb5_add_refresh_token_column.py | 24 ++++++ tests/cli/test_cli_core.py | 1 - tests/integration/conftest.py | 18 ----- tests/integration/endpoints/conftest.py | 20 +++++ tests/integration/services/conftest.py | 74 +++++++++++++++++++ .../services/test_authentication_service.py | 18 +++++ .../services/test_encryption_service.py | 16 ++++ .../clients/authentication_client/__init__.py | 0 .../authentication_client/dtos/__init__.py | 0 .../dtos/tokens_request.py | 9 +++ .../dtos/tokens_response.py | 9 +++ .../authentication_client/exceptions.py | 2 + .../google_oauth_client.py | 31 ++++++++ .../clients/google_api_client/__init__.py | 0 .../clients/google_api_client/exceptions.py | 2 + .../google_api_client/google_api_client.py | 23 ++++++ trailblazer/containers.py | 30 ++++++++ trailblazer/dto/authentication/__init__.py | 0 .../authentication/code_exchange_request.py | 5 ++ trailblazer/server/api.py | 19 +++++ .../authentication_service/__init__.py | 0 .../authentication_service.py | 36 +++++++++ .../authentication_service/exceptions.py | 6 ++ .../services/encryption_service/__init__.py | 0 .../encryption_service/encryption_service.py | 12 +++ .../services/encryption_service/utils.py | 35 +++++++++ trailblazer/store/crud/read.py | 7 ++ trailblazer/store/crud/update.py | 6 ++ trailblazer/store/filters/user_filters.py | 8 ++ trailblazer/store/models.py | 1 + 31 files changed, 396 insertions(+), 19 deletions(-) create mode 100644 alembic/versions/2024_03_08_febdb4e78bb5_add_refresh_token_column.py create mode 100644 tests/integration/endpoints/conftest.py create mode 100644 tests/integration/services/test_authentication_service.py create mode 100644 tests/integration/services/test_encryption_service.py create mode 100644 trailblazer/clients/authentication_client/__init__.py create mode 100644 trailblazer/clients/authentication_client/dtos/__init__.py create mode 100644 trailblazer/clients/authentication_client/dtos/tokens_request.py create mode 100644 trailblazer/clients/authentication_client/dtos/tokens_response.py create mode 100644 trailblazer/clients/authentication_client/exceptions.py create mode 100644 trailblazer/clients/authentication_client/google_oauth_client.py create mode 100644 trailblazer/clients/google_api_client/__init__.py create mode 100644 trailblazer/clients/google_api_client/exceptions.py create mode 100644 trailblazer/clients/google_api_client/google_api_client.py create mode 100644 trailblazer/dto/authentication/__init__.py create mode 100644 trailblazer/dto/authentication/code_exchange_request.py create mode 100644 trailblazer/services/authentication_service/__init__.py create mode 100644 trailblazer/services/authentication_service/authentication_service.py create mode 100644 trailblazer/services/authentication_service/exceptions.py create mode 100644 trailblazer/services/encryption_service/__init__.py create mode 100644 trailblazer/services/encryption_service/encryption_service.py create mode 100644 trailblazer/services/encryption_service/utils.py diff --git a/Dockerfile b/Dockerfile index 49887f81..3c00c337 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,6 +8,9 @@ ENV GUNICORN_TIMEOUT=400 ENV SECRET_KEY="Authkey" ENV SQLALCHEMY_DATABASE_URI="sqlite:///:memory:" ENV ANALYSIS_HOST="a_host" +ENV GOOGLE_CLIENT_ID="a_client_id" +ENV GOOGLE_CLIENT_SECRET="a_client_secret" +ENV GOOGLE_REDIRECT_URI="http://localhost:8000/auth" WORKDIR /home/src/app COPY . /home/src/app diff --git a/alembic/versions/2024_03_08_febdb4e78bb5_add_refresh_token_column.py b/alembic/versions/2024_03_08_febdb4e78bb5_add_refresh_token_column.py new file mode 100644 index 00000000..2b8921d7 --- /dev/null +++ b/alembic/versions/2024_03_08_febdb4e78bb5_add_refresh_token_column.py @@ -0,0 +1,24 @@ +"""Add refresh token column + +Revision ID: febdb4e78bb5 +Revises: e907e651fb9e +Create Date: 2024-03-08 10:38:28.151275 + +""" + +# revision identifiers, used by Alembic. +revision = "febdb4e78bb5" +down_revision = "e907e651fb9e" +branch_labels = None +depends_on = None + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.add_column("user", sa.Column("refresh_token", sa.Text, nullable=True)) + + +def downgrade(): + op.drop_column("user", "refresh_token") diff --git a/tests/cli/test_cli_core.py b/tests/cli/test_cli_core.py index d70f6446..f1806ce9 100644 --- a/tests/cli/test_cli_core.py +++ b/tests/cli/test_cli_core.py @@ -21,7 +21,6 @@ unarchive_user, ) from trailblazer.constants import CharacterFormat, SlurmJobStatus, TrailblazerStatus -from trailblazer.containers import Container from trailblazer.store.models import Analysis FUNC_GET_SLURM_SQUEUE_OUTPUT_PATH: str = "trailblazer.store.crud.update.get_slurm_squeue_output" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d2d9aa4b..42be1e4c 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,11 +1,7 @@ import datetime import uuid -from typing import Generator -from unittest.mock import patch import pytest -from flask import Flask -from flask.testing import FlaskClient from sqlalchemy.orm import Session from trailblazer.constants import ( @@ -16,22 +12,8 @@ TrailblazerStatus, WorkflowManager, ) -from trailblazer.server.app import app from trailblazer.store.database import get_session from trailblazer.store.models import Analysis, Job -from trailblazer.store.store import Store - - -@pytest.fixture -def flask_app(store: Store): - yield app - - -@pytest.fixture -def client(flask_app: Flask) -> Generator[FlaskClient, None, None]: - # Bypass authentication - with patch.object(flask_app, "before_request_funcs", new={}): - yield flask_app.test_client() @pytest.fixture diff --git a/tests/integration/endpoints/conftest.py b/tests/integration/endpoints/conftest.py new file mode 100644 index 00000000..62f760af --- /dev/null +++ b/tests/integration/endpoints/conftest.py @@ -0,0 +1,20 @@ +from typing import Generator +from unittest.mock import patch +from flask import Flask +from flask.testing import FlaskClient +import pytest + +from trailblazer.server.app import app +from trailblazer.store.store import Store + + +@pytest.fixture +def flask_app(user_store: Store): + yield app + + +@pytest.fixture +def client(flask_app: Flask) -> Generator[FlaskClient, None, None]: + # Bypass authentication + with patch.object(flask_app, "before_request_funcs", new={}): + yield flask_app.test_client() diff --git a/tests/integration/services/conftest.py b/tests/integration/services/conftest.py index 97d6ec6e..4fcb389d 100644 --- a/tests/integration/services/conftest.py +++ b/tests/integration/services/conftest.py @@ -1,8 +1,16 @@ +import base64 from datetime import datetime +import os import pytest +from requests_mock import Mocker + +from trailblazer.clients.authentication_client.google_oauth_client import GoogleOAuthClient +from trailblazer.clients.google_api_client.google_api_client import GoogleAPIClient from trailblazer.clients.slurm_cli_client.slurm_cli_client import SlurmCLIClient from trailblazer.constants import TrailblazerStatus +from trailblazer.services.authentication_service.authentication_service import AuthenticationService +from trailblazer.services.encryption_service.encryption_service import EncryptionService from trailblazer.services.job_service import JobService from trailblazer.services.slurm.dtos import SlurmJobInfo from trailblazer.services.slurm.slurm_cli_service.slurm_cli_service import SlurmCLIService @@ -26,3 +34,69 @@ def job_service(analysis_store: Store): slurm_client = SlurmCLIClient("host") slurm_service = SlurmCLIService(slurm_client) return JobService(slurm_service=slurm_service, store=analysis_store) + + +@pytest.fixture +def encryption_service() -> EncryptionService: + key: bytes = os.urandom(32) + secret_key: str = base64.b64encode(key).decode() + return EncryptionService(secret_key) + + +@pytest.fixture +def google_oauth_client(google_oauth_response: dict, mock_request: Mocker) -> GoogleOAuthClient: + + token_uri = "https://oauth2.googleapis.com/token" + mock_request.post(token_uri, json=google_oauth_response) + + return GoogleOAuthClient( + client_id="client_id", + client_secret="client_secret", + redirect_uri="redirect_uri", + token_uri=token_uri, + ) + + +@pytest.fixture +def google_user_info_response(user_email: str) -> dict: + return {"email": f"{user_email}"} + + +@pytest.fixture +def google_api_client(mock_request: Mocker, google_user_info_response: dict) -> GoogleAPIClient: + base_url = "https://www.googleapis.com" + user_info_endpoint = f"{base_url}/oauth2/v1/userinfo" + mock_request.get(user_info_endpoint, json=google_user_info_response) + return GoogleAPIClient("https://www.googleapis.com") + + +@pytest.fixture +def authentication_service( + encryption_service: EncryptionService, + google_oauth_client: GoogleOAuthClient, + google_api_client: GoogleAPIClient, + user_store: Store, +) -> AuthenticationService: + return AuthenticationService( + encryption_service=encryption_service, + store=user_store, + google_oauth_client=google_oauth_client, + google_api_client=google_api_client, + ) + + +@pytest.fixture +def google_oauth_response() -> dict: + return { + "access_token": "access_token", + "token_type": "token_type", + "expires_in": 3600, + "refresh_token": "refresh_token", + "scope": "scope", + } + + +@pytest.fixture +def mock_request(): + with Mocker() as mock: + yield mock diff --git a/tests/integration/services/test_authentication_service.py b/tests/integration/services/test_authentication_service.py new file mode 100644 index 00000000..04986671 --- /dev/null +++ b/tests/integration/services/test_authentication_service.py @@ -0,0 +1,18 @@ +from requests_mock import Mocker + +from trailblazer.services.authentication_service.authentication_service import AuthenticationService +from trailblazer.store.models import User + + +def test_exchange_code(authentication_service: AuthenticationService, user_email: str): + # GIVEN an authentication service + + # WHEN exchanging the authorization code + token: str = authentication_service.authenticate("auth_code") + + # THEN an access token is returned + assert token + + # THEN the refresh token is stored on the user + user: User = authentication_service.store.get_user(user_email) + assert user.refresh_token diff --git a/tests/integration/services/test_encryption_service.py b/tests/integration/services/test_encryption_service.py new file mode 100644 index 00000000..3815bfef --- /dev/null +++ b/tests/integration/services/test_encryption_service.py @@ -0,0 +1,16 @@ +from trailblazer.services.encryption_service.encryption_service import EncryptionService + + +def test_encryption(encryption_service: EncryptionService): + # GIVEN a plain text + plain_text = "my secret" + + # WHEN encrypting and decrypting the plain text + encrypted_text: str = encryption_service.encrypt(plain_text) + decrypted_text: str = encryption_service.decrypt(encrypted_text) + + # THEN the decrypted text is the same as the plain text + assert decrypted_text == plain_text + + # THEN the encrypted text is different from the plain text + assert encrypted_text != plain_text diff --git a/trailblazer/clients/authentication_client/__init__.py b/trailblazer/clients/authentication_client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trailblazer/clients/authentication_client/dtos/__init__.py b/trailblazer/clients/authentication_client/dtos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trailblazer/clients/authentication_client/dtos/tokens_request.py b/trailblazer/clients/authentication_client/dtos/tokens_request.py new file mode 100644 index 00000000..e24694cd --- /dev/null +++ b/trailblazer/clients/authentication_client/dtos/tokens_request.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +class GetTokensRequest(BaseModel): + client_id: str + client_secret: str + code: str + grant_type: str = "authorization_code" + redirect_uri: str diff --git a/trailblazer/clients/authentication_client/dtos/tokens_response.py b/trailblazer/clients/authentication_client/dtos/tokens_response.py new file mode 100644 index 00000000..7dbae9bd --- /dev/null +++ b/trailblazer/clients/authentication_client/dtos/tokens_response.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +class TokensResponse(BaseModel): + access_token: str + expires_in: int + refresh_token: str + scope: str + token_type: str diff --git a/trailblazer/clients/authentication_client/exceptions.py b/trailblazer/clients/authentication_client/exceptions.py new file mode 100644 index 00000000..a9a55305 --- /dev/null +++ b/trailblazer/clients/authentication_client/exceptions.py @@ -0,0 +1,2 @@ +class GoogleOAuthClientError(Exception): + pass diff --git a/trailblazer/clients/authentication_client/google_oauth_client.py b/trailblazer/clients/authentication_client/google_oauth_client.py new file mode 100644 index 00000000..38304ad8 --- /dev/null +++ b/trailblazer/clients/authentication_client/google_oauth_client.py @@ -0,0 +1,31 @@ +import requests + +from trailblazer.clients.authentication_client.dtos.tokens_request import GetTokensRequest +from trailblazer.clients.authentication_client.dtos.tokens_response import TokensResponse +from trailblazer.clients.authentication_client.exceptions import GoogleOAuthClientError + + +class GoogleOAuthClient: + + def __init__(self, client_id: str, client_secret: str, redirect_uri: str, token_uri: str): + self.client_id = client_id + self.client_secret = client_secret + self.token_uri = token_uri + self.redirect_uri = redirect_uri + + def get_tokens(self, authorization_code: str) -> TokensResponse: + """Exchange the authorization code for an access token and refresh token.""" + request = GetTokensRequest( + client_id=self.client_id, + client_secret=self.client_secret, + code=authorization_code, + redirect_uri=self.redirect_uri, + ) + data: dict = request.model_dump() + + response = requests.post(self.token_uri, data=data) + + if not response.ok: + raise GoogleOAuthClientError(response.text) + + return TokensResponse.model_validate(response.json()) diff --git a/trailblazer/clients/google_api_client/__init__.py b/trailblazer/clients/google_api_client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trailblazer/clients/google_api_client/exceptions.py b/trailblazer/clients/google_api_client/exceptions.py new file mode 100644 index 00000000..8ef4c86a --- /dev/null +++ b/trailblazer/clients/google_api_client/exceptions.py @@ -0,0 +1,2 @@ +class GoogleAPIClientError(Exception): + pass diff --git a/trailblazer/clients/google_api_client/google_api_client.py b/trailblazer/clients/google_api_client/google_api_client.py new file mode 100644 index 00000000..efaef8e1 --- /dev/null +++ b/trailblazer/clients/google_api_client/google_api_client.py @@ -0,0 +1,23 @@ +import requests + +from trailblazer.clients.google_api_client.exceptions import GoogleAPIClientError + + +class GoogleAPIClient: + + def __init__(self, base_url: str): + self.base_url = base_url + + def _get_headers(self, access_token: str) -> dict: + return {"Authorization": f"Bearer {access_token}"} + + def get_user_email(self, access_token: str) -> str: + """Get the user email for the given access token.""" + endpoint: str = f"{self.base_url}/oauth2/v1/userinfo" + headers: dict = self._get_headers(access_token) + response = requests.get(endpoint, headers=headers) + + if not response.ok: + raise GoogleAPIClientError(response.text) + + return response.json()["email"] diff --git a/trailblazer/containers.py b/trailblazer/containers.py index f861701d..c834ec0d 100644 --- a/trailblazer/containers.py +++ b/trailblazer/containers.py @@ -1,8 +1,12 @@ import os from dependency_injector import containers, providers +from trailblazer.clients.authentication_client.google_oauth_client import GoogleOAuthClient +from trailblazer.clients.google_api_client.google_api_client import GoogleAPIClient from trailblazer.clients.slurm_cli_client.slurm_cli_client import SlurmCLIClient from trailblazer.services.analysis_service.analysis_service import AnalysisService +from trailblazer.services.authentication_service.authentication_service import AuthenticationService +from trailblazer.services.encryption_service.encryption_service import EncryptionService from trailblazer.services.job_service import JobService from trailblazer.services.slurm.slurm_cli_service.slurm_cli_service import SlurmCLIService from trailblazer.store.store import Store @@ -10,6 +14,22 @@ class Container(containers.DeclarativeContainer): slurm_host: str | None = os.environ.get("ANALYSIS_HOST") + oauth_client_id: str = os.environ.get("GOOGLE_CLIENT_ID") + oauth_client_secret: str = os.environ.get("GOOGLE_CLIENT_SECRET") + oauth_redirect_uri: str = os.environ.get("GOOGLE_REDIRECT_URI") + oauth_token_uri: str = os.environ.get("TOKEN_URI") + encryption_key: str = os.environ.get("SECRET_KEY") + google_api_base_url: str = os.environ.get("GOOGLE_API_BASE_URL") + + google_api_client = GoogleAPIClient(google_api_base_url) + + google_oauth_client = providers.Singleton( + GoogleOAuthClient, + client_id=oauth_client_id, + client_secret=oauth_client_secret, + redirect_uri=oauth_redirect_uri, + token_uri=oauth_token_uri, + ) store = providers.Singleton(Store) slurm_client = providers.Singleton(SlurmCLIClient, host=slurm_host) @@ -17,3 +37,13 @@ class Container(containers.DeclarativeContainer): job_service = providers.Factory(JobService, store=store, slurm_service=slurm_service) analysis_service = providers.Factory(AnalysisService, store=store) + + encryption_service = providers.Singleton(EncryptionService, secret_key=encryption_key) + + auth_service = providers.Singleton( + AuthenticationService, + google_oauth_client=google_oauth_client, + google_api_client=google_api_client, + encryption_service=encryption_service, + store=store, + ) diff --git a/trailblazer/dto/authentication/__init__.py b/trailblazer/dto/authentication/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trailblazer/dto/authentication/code_exchange_request.py b/trailblazer/dto/authentication/code_exchange_request.py new file mode 100644 index 00000000..fd5c97b7 --- /dev/null +++ b/trailblazer/dto/authentication/code_exchange_request.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class CodeExchangeRequest(BaseModel): + code: str diff --git a/trailblazer/server/api.py b/trailblazer/server/api.py index 07680a7a..eaf2c049 100644 --- a/trailblazer/server/api.py +++ b/trailblazer/server/api.py @@ -27,6 +27,7 @@ FailedJobsResponse, ) from trailblazer.dto.analyses_response import UpdateAnalysesResponse +from trailblazer.dto.authentication.code_exchange_request import CodeExchangeRequest from trailblazer.dto.create_analysis_request import CreateAnalysisRequest from trailblazer.dto.summaries_request import SummariesRequest from trailblazer.dto.summaries_response import SummariesResponse @@ -38,6 +39,8 @@ stringify_timestamps, ) from trailblazer.services.analysis_service.analysis_service import AnalysisService +from trailblazer.services.authentication_service.authentication_service import AuthenticationService +from trailblazer.services.authentication_service.exceptions import AuthenticationError from trailblazer.services.job_service import JobService from trailblazer.store.models import Info @@ -47,6 +50,8 @@ @blueprint.before_request def before_request(): """Authentication that is run before processing requests to the application""" + if request.endpoint == "api.authenticate": + return if request.method == "OPTIONS": return make_response(jsonify(ok=True), 204) if os.environ.get("SCOPE") == "DEVELOPMENT": @@ -63,6 +68,20 @@ def before_request(): return abort(403, f"{user_data['email']} doesn't have access") +@blueprint.route("/auth", methods=["POST"]) +@inject +def authenticate(auth_service: AuthenticationService = Provide[Container.auth_service]): + """Exchange authorization code for an access token.""" + try: + request_data = CodeExchangeRequest.model_validate(request.json) + token: str = auth_service.authenticate(request_data.code) + return jsonify({"access_token": token}), HTTPStatus.OK + except ValidationError as error: + return jsonify(error=str(error)), HTTPStatus.BAD_REQUEST + except AuthenticationError: + return jsonify("User not allowed"), HTTPStatus.FORBIDDEN + + @blueprint.route("/analyses", methods=["GET"]) @inject def get_analyses(analysis_service: AnalysisService = Provide[Container.analysis_service]): diff --git a/trailblazer/services/authentication_service/__init__.py b/trailblazer/services/authentication_service/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trailblazer/services/authentication_service/authentication_service.py b/trailblazer/services/authentication_service/authentication_service.py new file mode 100644 index 00000000..8f1b3878 --- /dev/null +++ b/trailblazer/services/authentication_service/authentication_service.py @@ -0,0 +1,36 @@ +from trailblazer.clients.authentication_client.dtos.tokens_response import TokensResponse +from trailblazer.clients.authentication_client.google_oauth_client import GoogleOAuthClient +from trailblazer.clients.google_api_client.google_api_client import GoogleAPIClient +from trailblazer.services.authentication_service.exceptions import UserNotFoundError +from trailblazer.services.encryption_service.encryption_service import EncryptionService +from trailblazer.store.models import User +from trailblazer.store.store import Store + + +class AuthenticationService: + + def __init__( + self, + google_oauth_client: GoogleOAuthClient, + google_api_client: GoogleAPIClient, + encryption_service: EncryptionService, + store: Store, + ): + self.google_oauth_client = google_oauth_client + self.google_api_client = google_api_client + self.encryption_service = encryption_service + self.store = store + + def authenticate(self, authorization_code: str) -> str: + """Exchange the authorization code for an access token.""" + tokens: TokensResponse = self.google_oauth_client.get_tokens(authorization_code) + user_email: str = self.google_api_client.get_user_email(tokens.access_token) + user: User | None = self.store.get_user(user_email) + + if not user: + raise UserNotFoundError + + encrypted_token: str = self.encryption_service.encrypt(tokens.refresh_token) + self.store.update_user_token(user_id=user.id, refresh_token=encrypted_token) + + return tokens.access_token diff --git a/trailblazer/services/authentication_service/exceptions.py b/trailblazer/services/authentication_service/exceptions.py new file mode 100644 index 00000000..e3911685 --- /dev/null +++ b/trailblazer/services/authentication_service/exceptions.py @@ -0,0 +1,6 @@ +class AuthenticationError(Exception): + pass + + +class UserNotFoundError(AuthenticationError): + pass diff --git a/trailblazer/services/encryption_service/__init__.py b/trailblazer/services/encryption_service/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trailblazer/services/encryption_service/encryption_service.py b/trailblazer/services/encryption_service/encryption_service.py new file mode 100644 index 00000000..422f8e7d --- /dev/null +++ b/trailblazer/services/encryption_service/encryption_service.py @@ -0,0 +1,12 @@ +from trailblazer.services.encryption_service.utils import decrypt_data, encrypt_data + + +class EncryptionService: + def __init__(self, secret_key: str): + self.secret_key = secret_key + + def encrypt(self, data: str) -> str: + return encrypt_data(data=data, secret_key=self.secret_key) + + def decrypt(self, data: str) -> str: + return decrypt_data(data=data, secret_key=self.secret_key) diff --git a/trailblazer/services/encryption_service/utils.py b/trailblazer/services/encryption_service/utils.py new file mode 100644 index 00000000..d6ed0f37 --- /dev/null +++ b/trailblazer/services/encryption_service/utils.py @@ -0,0 +1,35 @@ +import base64 +import os +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +NONCE_SIZE = 12 + + +def encrypt_data(data: str, secret_key: str) -> str: + data: bytes = data.encode() + cipher_text: bytes = _encrypt_with_aes(data=data, secret_key=secret_key) + return _bytes_to_string(cipher_text) + + +def decrypt_data(data: str, secret_key: str) -> str: + data: bytes = base64.b64decode(data) + plain_text: bytes = _decrypt_with_aes(data=data, secret_key=secret_key) + return plain_text.decode() + + +def _encrypt_with_aes(data: bytes, secret_key: str) -> bytes: + cipher = AESGCM(base64.b64decode(secret_key)) + nonce: bytes = os.urandom(NONCE_SIZE) + cipher_text: bytes = cipher.encrypt(nonce=nonce, data=data, associated_data=None) + return nonce + cipher_text + + +def _decrypt_with_aes(data: bytes, secret_key: str) -> bytes: + cipher = AESGCM(base64.b64decode(secret_key)) + nonce: bytes = data[:NONCE_SIZE] + cipher_text: bytes = data[NONCE_SIZE:] + return cipher.decrypt(nonce=nonce, data=cipher_text, associated_data=None) + + +def _bytes_to_string(data: bytes) -> str: + return base64.b64encode(data).decode() diff --git a/trailblazer/store/crud/read.py b/trailblazer/store/crud/read.py index b8958320..2b2dee20 100644 --- a/trailblazer/store/crud/read.py +++ b/trailblazer/store/crud/read.py @@ -139,6 +139,13 @@ def get_user( email=email, ).first() + def get_user_by_id(self, user_id: int) -> User | None: + return apply_user_filter( + filter_functions=[UserFilter.BY_ID], + users=self.get_query(table=User), + id=user_id, + ).first() + def get_users( self, name: str = None, diff --git a/trailblazer/store/crud/update.py b/trailblazer/store/crud/update.py index de5918a2..4f202915 100644 --- a/trailblazer/store/crud/update.py +++ b/trailblazer/store/crud/update.py @@ -283,3 +283,9 @@ def update_job(self, job_id: int, job_info: SlurmJobInfo) -> Job: job.started_at = job_info.started_at session: Session = get_session() session.commit() + + def update_user_token(self, refresh_token: str, user_id: int) -> None: + user: User | None = self.get_user_by_id(user_id) + user.refresh_token = refresh_token + session: Session = get_session() + session.commit() diff --git a/trailblazer/store/filters/user_filters.py b/trailblazer/store/filters/user_filters.py index d7a63588..a2ace256 100644 --- a/trailblazer/store/filters/user_filters.py +++ b/trailblazer/store/filters/user_filters.py @@ -26,9 +26,15 @@ def filter_users_by_is_not_archived(users: Query, **kwargs) -> Query: return users.filter(User.is_archived.is_(False)) +def filter_users_by_id(users: Query, id: int, **kwargs) -> Query: + """Filter users by id.""" + return users.filter(User.id == id) + + class UserFilter(Enum): """Define User filter functions.""" + BY_ID: Callable = filter_users_by_id BY_CONTAINS_EMAIL: Callable = filter_users_by_contains_email BY_CONTAINS_NAME: Callable = filter_users_by_contains_name BY_EMAIL: Callable = filter_users_by_email @@ -39,6 +45,7 @@ def apply_user_filter( users: Query, filter_functions: list[Callable], email: str | None = None, + id: int | None = None, name: str | None = None, ) -> Query: """Apply filtering functions and return filtered results.""" @@ -46,6 +53,7 @@ def apply_user_filter( users: Query = function( users=users, email=email, + id=id, name=name, ) return users diff --git a/trailblazer/store/models.py b/trailblazer/store/models.py index f8016c51..c57ec4fc 100644 --- a/trailblazer/store/models.py +++ b/trailblazer/store/models.py @@ -40,6 +40,7 @@ class User(Model): id = Column(types.Integer, primary_key=True) is_archived = Column(types.Boolean, default=False) name = Column(types.String(128)) + refresh_token = Column(types.Text) runs = orm.relationship("Analysis", backref="user") From 737bb5a19175b979443d891aae66ce6e50eff368 Mon Sep 17 00:00:00 2001 From: Clinical Genomics Bot Date: Mon, 11 Mar 2024 08:51:26 +0000 Subject: [PATCH 2/2] =?UTF-8?q?Bump=20version:=2021.0.8=20=E2=86=92=2021.1?= =?UTF-8?q?.0=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index c72e9763..ccfddafd 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 21.0.8 +current_version = 21.1.0 commit = True tag = True tag_name = {new_version} diff --git a/setup.py b/setup.py index 5de01c31..8d7f0fbc 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ def parse_reqs(req_path="./requirements.txt"): setup( name=NAME, - version="21.0.8", + version="21.1.0", description=DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type="text/markdown",