From 36565bf86d12fe39753605384a410ac0d8ea0d4f Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Tue, 25 Feb 2025 15:17:07 -0300 Subject: [PATCH 01/15] feat: add get_claims method --- poetry.lock | 10 ++--- pyproject.toml | 1 + supabase_auth/_async/gotrue_client.py | 49 +++++++++++++++++++++-- supabase_auth/errors.py | 8 ++++ supabase_auth/helpers.py | 56 +++++++++++++++++++++++---- supabase_auth/types.py | 30 ++++++++++++++ 6 files changed, 139 insertions(+), 15 deletions(-) diff --git a/poetry.lock b/poetry.lock index aecdb529..d3c96613 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1003,13 +1003,13 @@ urllib3 = ">=1.26.0" [[package]] name = "pyjwt" -version = "2.9.0" +version = "2.10.1" description = "JSON Web Token implementation in Python" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"}, - {file = "pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c"}, + {file = "PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb"}, + {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 28dff73c..d6873904 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ python = "^3.9" httpx = {version = ">=0.26,<0.29", extras = ["http2"]} pydantic = ">=1.10,<3" +pyjwt = "^2.10.1" [tool.poetry.dev-dependencies] pytest = "^8.3.5" diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index 2a2a5564..4069991b 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -4,7 +4,7 @@ from functools import partial from json import loads from time import time -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 @@ -20,11 +20,12 @@ AuthApiError, AuthImplicitGrantRedirectError, AuthInvalidCredentialsError, + AuthInvalidJwtError, AuthRetryableError, AuthSessionMissingError, ) from ..helpers import ( - decode_jwt_payload, + decode_jwt, generate_pkce_challenge, generate_pkce_verifier, model_dump, @@ -39,6 +40,8 @@ from ..http_clients import AsyncClient from ..timer import Timer from ..types import ( + JWK, + JWKS, AuthChangeEvent, AuthenticatorAssuranceLevels, AuthFlowType, @@ -106,6 +109,7 @@ def __init__( verify=verify, proxy=proxy, ) + self._jwks: JWKS = {} self._storage_key = storage_key or STORAGE_KEY self._auto_refresh_token = auto_refresh_token self._persist_session = persist_session @@ -1128,7 +1132,8 @@ def _decode_jwt(self, jwt: str) -> DecodedJWTDict: """ Decodes a JWT (without performing any validation). """ - return decode_jwt_payload(jwt) + decoded = decode_jwt(jwt) + return decoded["payload"] async def exchange_code_for_session(self, params: CodeExchangeParams): code_verifier = params.get("code_verifier") or await self._storage.get_item( @@ -1150,3 +1155,41 @@ async def exchange_code_for_session(self, params: CodeExchangeParams): await self._save_session(response.session) self._notify_all_subscribers("SIGNED_IN", response.session) return response + + async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK: + # try fetching from the suplied keys. + jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None) + + if jwk: + return jwk + + # try fetching from the cache. + jwk = next((jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None) + if jwk: + return jwk + + # jwk isn't cached in memory so we need to fetch it from the well-known endpoint + response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks) + if response.jwks: + self._jwks = response.jwks + + # find the signing key + jwk = next( + (jwk for jwk in response.jwks["keys"] if jwk["kid"] == kid), None + ) + if not jwk: + raise AuthInvalidJwtError("No matching signing key found in JWKS") + + return jwk + + raise AuthInvalidJwtError("JWT has no valid kid") + + async def get_claims(self): + pass + + +def parse_jwks(response: Any) -> JWKS: + if "keys" not in response or len(response.keys) == 0: + raise AuthInvalidJwtError("JWKS is empty") + + return JWKS(keys=response.keys) diff --git a/supabase_auth/errors.py b/supabase_auth/errors.py index fa693894..1cf8de3c 100644 --- a/supabase_auth/errors.py +++ b/supabase_auth/errors.py @@ -225,3 +225,11 @@ def to_dict(self) -> AuthApiErrorDict: "status": self.status, "reasons": self.reasons, } + +class AuthInvalidJwtError(CustomAuthError): + def __init__(self, message: str) -> None: + CustomAuthError.__init__( + self, + message, + "AuthInvalidJwtError", + "invalid_jwt", \ No newline at end of file diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py index ae9dc7c5..8be212e2 100644 --- a/supabase_auth/helpers.py +++ b/supabase_auth/helpers.py @@ -8,16 +8,19 @@ from base64 import urlsafe_b64decode from datetime import datetime from json import loads -from typing import Any, Dict, Optional, Type, TypeVar, cast +from typing import Any, Callable, Dict, Literal, Optional, Type, TypeVar, cast from urllib.parse import urlparse from httpx import HTTPStatusError, Response +import jwt +import jwt.algorithms from pydantic import BaseModel from .constants import API_VERSION_HEADER_NAME, API_VERSIONS from .errors import ( AuthApiError, AuthError, + AuthInvalidJwtError, AuthRetryableError, AuthUnknownError, AuthWeakPasswordError, @@ -192,15 +195,38 @@ def handle_exception(exception: Exception) -> AuthError: return AuthUnknownError(get_error_message(error), e) -def decode_jwt_payload(token: str) -> Any: - parts = token.split(".") - if len(parts) != 3: - raise ValueError("JWT is not valid: not a JWT structure") - base64url = parts[1] +def str_from_base64url(base64url: str) -> str: # Addding padding otherwise the following error happens: # binascii.Error: Incorrect padding base64url_with_padding = base64url + "=" * (-len(base64url) % 4) - return loads(urlsafe_b64decode(base64url_with_padding).decode("utf-8")) + return urlsafe_b64decode(base64url_with_padding).decode("utf-8") + + +def base64url_to_bytes(base64url: str) -> bytes: + # Addding padding otherwise the following error happens: + # binascii.Error: Incorrect padding + base64url_with_padding = base64url + "=" * (-len(base64url) % 4) + return urlsafe_b64decode(base64url_with_padding) + + +def decode_jwt(token: str) -> Dict[str, Any]: + parts = token.split(".") + if len(parts) != 3: + raise AuthInvalidJwtError("Invalid JWT structure") + + # regex check for base64url + if not re.match(BASE64URL_REGEX, parts[1]): + raise AuthInvalidJwtError("JWT not in base64url format") + + return { + "header": loads(str_from_base64url(parts[0])), + "payload": loads(str_from_base64url(parts[1])), + "signature": base64url_to_bytes(parts[2]), + "raw": { + "header": parts[0], + "payload": parts[1], + }, + } def generate_pkce_verifier(length=64): @@ -267,3 +293,19 @@ def is_valid_jwt(value: str) -> bool: return False return True + + +def validate_exp(exp: int) -> None: + if not exp: + raise AuthInvalidJwtError("JWT has no expiration time") + + time_now = datetime.now().timestamp() + if exp <= time_now: + raise AuthInvalidJwtError("JWT has expired") + + +def get_algorithm(alg: Literal["RS256", "ES256"]) -> jwt.algorithms.Algorithm: + if alg == "RS256": + return jwt.algorithms.RSAAlgorithm + elif alg == "ES256": + return jwt.algorithms.ECAlgorithm diff --git a/supabase_auth/types.py b/supabase_auth/types.py index 86bda3e2..dd499aa0 100644 --- a/supabase_auth/types.py +++ b/supabase_auth/types.py @@ -789,6 +789,36 @@ class SignOutOptions(TypedDict): scope: NotRequired[SignOutScope] +class JWTHeader(TypedDict): + alg: Literal["RS256", "ES256", "HS256"] + typ: str + kid: str + + +class RequiredClaims(TypedDict): + iss: str + sub: str + auth: Union[str, List[str]] + exp: int + iat: int + role: str + aal: AuthenticatorAssuranceLevels + session_id: str + + +class JWTPayload(RequiredClaims, TypedDict, total=False): + pass + + +class JWK(TypedDict, total=False): + kty: Literal["RSA", "EC", "oct"] + key_ops: List[str] + alg: Optional[str] + kid: Optional[str] + +class JWKS(TypedDict): + keys: List[JWK] + for model in [ AMREntry, AuthResponse, From 023d610f78b097bb8dc2937cc9d5c4ed5c18be51 Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Tue, 25 Feb 2025 15:31:01 -0300 Subject: [PATCH 02/15] sync infra --- infra/db/00-schema.sql | 1 + infra/docker-compose.yml | 53 +++++++++++++++++++++++++++++++--------- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/infra/db/00-schema.sql b/infra/db/00-schema.sql index 229af46f..ba844794 100644 --- a/infra/db/00-schema.sql +++ b/infra/db/00-schema.sql @@ -83,3 +83,4 @@ GRANT ALL PRIVILEGES ON SCHEMA auth TO postgres; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA auth TO postgres; GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA auth TO postgres; ALTER USER postgres SET search_path = "auth"; + diff --git a/infra/docker-compose.yml b/infra/docker-compose.yml index a4cf04c7..29d4d1ed 100644 --- a/infra/docker-compose.yml +++ b/infra/docker-compose.yml @@ -2,7 +2,7 @@ version: '3' services: gotrue: # Signup enabled, autoconfirm off - image: supabase/auth:v2.132.3 + image: supabase/auth:v2.151.0 ports: - '9999:9999' environment: @@ -32,21 +32,22 @@ services: GOTRUE_SMTP_ADMIN_EMAIL: admin@email.com GOTRUE_MAILER_SUBJECTS_CONFIRMATION: 'Please confirm' GOTRUE_EXTERNAL_PHONE_ENABLED: 'true' - GOTRUE_SMS_PROVIDER: "twilio" - GOTRUE_SMS_TWILIO_ACCOUNT_SID: "${GOTRUE_SMS_TWILIO_ACCOUNT_SID}" - GOTRUE_SMS_TWILIO_AUTH_TOKEN: "${GOTRUE_SMS_TWILIO_AUTH_TOKEN}" - GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID: "${GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID}" + GOTRUE_SMS_PROVIDER: 'twilio' + GOTRUE_SMS_TWILIO_ACCOUNT_SID: '${GOTRUE_SMS_TWILIO_ACCOUNT_SID}' + GOTRUE_SMS_TWILIO_AUTH_TOKEN: '${GOTRUE_SMS_TWILIO_AUTH_TOKEN}' + GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID: '${GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID}' GOTRUE_SMS_AUTOCONFIRM: 'false' - GOTRUE_COOKIE_KEY: "sb" + GOTRUE_COOKIE_KEY: 'sb' depends_on: - db restart: on-failure autoconfirm: # Signup enabled, autoconfirm on - image: supabase/auth:v2.132.3 + image: supabase/auth:v2.151.0 ports: - '9998:9998' environment: GOTRUE_JWT_SECRET: '37c304f8-51aa-419a-a1af-06154e63707a' + GOTRUE_JWT_KEYS: '[{"kty":"oct","k":"Z7-AyPyChGNcQsX16cPBV-pPBo4q-zckDxkq1VZjATo","kid":"12580317-221c-49b6-894a-f4473b8afe39","key_ops":["sign", "verify"],"alg":"HS256"},{"kty":"RSA","n":"y3KQnIXK6wkPQ5m0XWp7z54BNZzXJk4IxXy81zFophdBBqz6u5OCMqWkC6i3WB7rlax4xjmxxyGyYRODooqCQTGahmpXryAAKc3g-gDIAq2MqVwlpmvXDavCVRK4hK7DZ6wK4MHrliSNHCuCkwIH3ofxTxgUwpSkOT58iU1ZOua5E1Y6R_Ozt3gLHha0Xa7a4V23pkP7n0xBvJPzIqiS3MZ4CQ_pz-buXYRgCPQkUJvXFFcuxmyqoYzorwQ1YVBOmH2XMx26RrCIxgj7geo9eVQ9u5qCPpQCGV5biqYMC4_m1kurOGf62URGRzXtmVzrW1PZJAeGoqMz5Fcfr8hiwQ","e":"AQAB","d":"C4XxquvpEmbw9mM-VAwz9w58Aw1fIkxJMuZdy9KAmue2RyqFCRrRxQycvgxQVi1qKpAaRx_9ccn20IjKa-psdkTY-8QKM2EcoUGH_KEOsxghX3ZYq5RwGdYgq7DjwqAjcTvNYe2Z6mcnlvDf9HOo_nG0uUYj5uGEa7meVCiNZUiSVdNGs-vOTUD8yB5pbZ4ute8ebuUzCWGQ3YwSNoWLa-dbECSO7jeobCapdB52MjEwE3_Ii8BWoySeDP-DEFX_5RTM2Zeh81zXAgmOxpZYTkjMsrznyxxBbXn7CdT8WMEXrreGZwIt3Mu6XpsLF5mwmTQ_ZyoM6tJpn5LeAhnCAQ","p":"6xy1skrnlrGUWtZFSHixn_eRA_O3GXKNBE4wziWodGZaFYsmFijZHbuQT0WFqc0epvLHNdNPvubFrVfV-U7ZIarfSSq6qBwBzDrDQS060MvjJIjrI16pKlx2X727FR1ZuwxT27dNg-wRTgKcZqXEalkvFOTEYBlCtw2-vzI0aRs","q":"3YWwOAs4GRZ9eq_fqNujACWJFyUO9QgEDPDOMg0EZhY7WkAlehTxxVXg65spWnfx_0GSc72I5N5qdbY-yDh2Dl7zIxvwnqZaKMJn4PEFkeAfyg62XlJlkHIwOVSj6vLNUDdDmG7bO2k6MyQ59jeuAemIljf9WhALNy8c9R0K3VM","dp":"KJ4LHcQnAjeng5Hk4kJHnXUtjls6VKEfj5DaiaKj2YgdI_-oEsf3ylUu9yLxloYjN4BVvgzFiBtiJzI3exyOEmzsqj1Bhe1guiGkvcvMj2nJ0fP9e1zNKM5UfPHQMjOh3tigXCLst0-_JZT55BnbNuw1YAytiFSU2_755xoLR-U","dq":"dCP7V-bJ6p1X_FLpOGau9wy262OKi_0_4mj-Mk-Q1tUhGRg4jeEdQRDdc6lN7Rilz-ZZGkVs2FGkD0MVd3PisXYmk2m6pfMhoe0K-WxkNy8Ce7Vq99jLVwgHMIenyS6zZjMTRYAZgPSShu2fVe-rU2VVLyz7r5RpzOzuibRIVfE","qi":"i7ND2teiVLkbaAs6rHfo5DiD1nlsORNYnn8Y_FjF6utb5OUljZ6-5WyEDJN9oIUX8o_Il9E6js-z7nhvPfFZHQN7ZWuYI0rO5qmsCDS9jWJ4GR61SgzZuLT7Jpp_KtwjW70x5wZ1Y-GugOP1Wct1YZWHn5YyLhvO6X_vttSmcS0","kid":"638c54b8-28c2-4b12-9598-ba12ef610a29","key_ops":["verify"],"alg":"RS256"}]' GOTRUE_JWT_EXP: 3600 GOTRUE_DB_DRIVER: postgres DB_NAMESPACE: auth @@ -66,12 +67,42 @@ services: GOTRUE_SMTP_USER: GOTRUE_SMTP_USER GOTRUE_SMTP_PASS: GOTRUE_SMTP_PASS GOTRUE_SMTP_ADMIN_EMAIL: admin@email.com - GOTRUE_COOKIE_KEY: "sb" + GOTRUE_COOKIE_KEY: 'sb' + depends_on: + - db + restart: on-failure + autoconfirm_with_asymmetric_keys: # Signup enabled, autoconfirm on + image: supabase/auth:v2.169.0 + ports: + - '9996:9996' + environment: + GOTRUE_JWT_SECRET: 'Z7-AyPyChGNcQsX16cPBV-pPBo4q-zckDxkq1VZjATo' + GOTRUE_JWT_KEYS: '[{"kty":"oct","k":"Z7-AyPyChGNcQsX16cPBV-pPBo4q-zckDxkq1VZjATo","kid":"12580317-221c-49b6-894a-f4473b8afe39","key_ops":["verify"],"alg":"HS256"},{"kty":"RSA","n":"y3KQnIXK6wkPQ5m0XWp7z54BNZzXJk4IxXy81zFophdBBqz6u5OCMqWkC6i3WB7rlax4xjmxxyGyYRODooqCQTGahmpXryAAKc3g-gDIAq2MqVwlpmvXDavCVRK4hK7DZ6wK4MHrliSNHCuCkwIH3ofxTxgUwpSkOT58iU1ZOua5E1Y6R_Ozt3gLHha0Xa7a4V23pkP7n0xBvJPzIqiS3MZ4CQ_pz-buXYRgCPQkUJvXFFcuxmyqoYzorwQ1YVBOmH2XMx26RrCIxgj7geo9eVQ9u5qCPpQCGV5biqYMC4_m1kurOGf62URGRzXtmVzrW1PZJAeGoqMz5Fcfr8hiwQ","e":"AQAB","d":"C4XxquvpEmbw9mM-VAwz9w58Aw1fIkxJMuZdy9KAmue2RyqFCRrRxQycvgxQVi1qKpAaRx_9ccn20IjKa-psdkTY-8QKM2EcoUGH_KEOsxghX3ZYq5RwGdYgq7DjwqAjcTvNYe2Z6mcnlvDf9HOo_nG0uUYj5uGEa7meVCiNZUiSVdNGs-vOTUD8yB5pbZ4ute8ebuUzCWGQ3YwSNoWLa-dbECSO7jeobCapdB52MjEwE3_Ii8BWoySeDP-DEFX_5RTM2Zeh81zXAgmOxpZYTkjMsrznyxxBbXn7CdT8WMEXrreGZwIt3Mu6XpsLF5mwmTQ_ZyoM6tJpn5LeAhnCAQ","p":"6xy1skrnlrGUWtZFSHixn_eRA_O3GXKNBE4wziWodGZaFYsmFijZHbuQT0WFqc0epvLHNdNPvubFrVfV-U7ZIarfSSq6qBwBzDrDQS060MvjJIjrI16pKlx2X727FR1ZuwxT27dNg-wRTgKcZqXEalkvFOTEYBlCtw2-vzI0aRs","q":"3YWwOAs4GRZ9eq_fqNujACWJFyUO9QgEDPDOMg0EZhY7WkAlehTxxVXg65spWnfx_0GSc72I5N5qdbY-yDh2Dl7zIxvwnqZaKMJn4PEFkeAfyg62XlJlkHIwOVSj6vLNUDdDmG7bO2k6MyQ59jeuAemIljf9WhALNy8c9R0K3VM","dp":"KJ4LHcQnAjeng5Hk4kJHnXUtjls6VKEfj5DaiaKj2YgdI_-oEsf3ylUu9yLxloYjN4BVvgzFiBtiJzI3exyOEmzsqj1Bhe1guiGkvcvMj2nJ0fP9e1zNKM5UfPHQMjOh3tigXCLst0-_JZT55BnbNuw1YAytiFSU2_755xoLR-U","dq":"dCP7V-bJ6p1X_FLpOGau9wy262OKi_0_4mj-Mk-Q1tUhGRg4jeEdQRDdc6lN7Rilz-ZZGkVs2FGkD0MVd3PisXYmk2m6pfMhoe0K-WxkNy8Ce7Vq99jLVwgHMIenyS6zZjMTRYAZgPSShu2fVe-rU2VVLyz7r5RpzOzuibRIVfE","qi":"i7ND2teiVLkbaAs6rHfo5DiD1nlsORNYnn8Y_FjF6utb5OUljZ6-5WyEDJN9oIUX8o_Il9E6js-z7nhvPfFZHQN7ZWuYI0rO5qmsCDS9jWJ4GR61SgzZuLT7Jpp_KtwjW70x5wZ1Y-GugOP1Wct1YZWHn5YyLhvO6X_vttSmcS0","kid":"638c54b8-28c2-4b12-9598-ba12ef610a29","key_ops":["sign","verify"],"alg":"RS256"}]' + GOTRUE_JWT_EXP: 3600 + GOTRUE_DB_DRIVER: postgres + DB_NAMESPACE: auth + GOTRUE_API_HOST: 0.0.0.0 + PORT: 9996 + GOTRUE_DISABLE_SIGNUP: 'false' + API_EXTERNAL_URL: http://localhost:9996 + GOTRUE_SITE_URL: http://localhost:9996 + GOTRUE_MAILER_AUTOCONFIRM: 'true' + GOTRUE_SMS_AUTOCONFIRM: 'true' + GOTRUE_LOG_LEVEL: DEBUG + GOTRUE_OPERATOR_TOKEN: super-secret-operator-token + DATABASE_URL: 'postgres://postgres:postgres@db:5432/postgres?sslmode=disable' + GOTRUE_EXTERNAL_PHONE_ENABLED: 'true' + GOTRUE_SMTP_HOST: mail + GOTRUE_SMTP_PORT: 2500 + GOTRUE_SMTP_USER: GOTRUE_SMTP_USER + GOTRUE_SMTP_PASS: GOTRUE_SMTP_PASS + GOTRUE_SMTP_ADMIN_EMAIL: admin@email.com + GOTRUE_COOKIE_KEY: 'sb' depends_on: - db restart: on-failure disabled: # Signup disabled - image: supabase/auth:v2.132.3 + image: supabase/auth:v2.151.0 ports: - '9997:9997' environment: @@ -95,7 +126,7 @@ services: GOTRUE_SMTP_USER: GOTRUE_SMTP_USER GOTRUE_SMTP_PASS: GOTRUE_SMTP_PASS GOTRUE_SMTP_ADMIN_EMAIL: admin@email.com - GOTRUE_COOKIE_KEY: "sb" + GOTRUE_COOKIE_KEY: 'sb' depends_on: - db restart: on-failure @@ -106,7 +137,7 @@ services: - '9000:9000' # web interface - '1100:1100' # POP3 db: - image: supabase/postgres:15.1.1.66 + image: supabase/postgres:15.1.1.46 ports: - '5432:5432' command: postgres -c config_file=/etc/postgresql/postgresql.conf From faaf88c654f4358c870c2bd268d51771184c9968 Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Tue, 25 Feb 2025 15:31:11 -0300 Subject: [PATCH 03/15] fix tests --- supabase_auth/errors.py | 3 ++- supabase_auth/helpers.py | 2 +- tests/test_helpers.py | 12 ++++++------ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/supabase_auth/errors.py b/supabase_auth/errors.py index 1cf8de3c..6764aa82 100644 --- a/supabase_auth/errors.py +++ b/supabase_auth/errors.py @@ -232,4 +232,5 @@ def __init__(self, message: str) -> None: self, message, "AuthInvalidJwtError", - "invalid_jwt", \ No newline at end of file + "invalid_jwt", + ) \ No newline at end of file diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py index 8be212e2..838d16d5 100644 --- a/supabase_auth/helpers.py +++ b/supabase_auth/helpers.py @@ -8,7 +8,7 @@ from base64 import urlsafe_b64decode from datetime import datetime from json import loads -from typing import Any, Callable, Dict, Literal, Optional, Type, TypeVar, cast +from typing import Any, Dict, Literal, Optional, Type, TypeVar, cast from urllib.parse import urlparse from httpx import HTTPStatusError, Response diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 57dcb66a..4132e9e6 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -6,9 +6,9 @@ from httpx import Headers, Response from supabase_auth.constants import API_VERSION_HEADER_NAME -from supabase_auth.errors import AuthApiError, AuthWeakPasswordError +from supabase_auth.errors import AuthApiError, AuthInvalidJwtError, AuthWeakPasswordError from supabase_auth.helpers import ( - decode_jwt_payload, + decode_jwt, generate_pkce_challenge, generate_pkce_verifier, get_error_code, @@ -111,13 +111,13 @@ def test_get_error_code(): assert get_error_code({"error_code": "500"}) == "500" -def test_decode_jwt_payload(): - assert decode_jwt_payload(mock_access_token()) +def test_decode_jwt(): + assert decode_jwt(mock_access_token()) with pytest.raises( - ValueError, match=r"JWT is not valid: not a JWT structure" + AuthInvalidJwtError, match=r"Invalid JWT structure" ) as exc: - decode_jwt_payload("non-valid-jwt") + decode_jwt("non-valid-jwt") assert exc.value is not None From 8ac31322ed0111fd81ae9d5ae9d2a521ca70be53 Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Tue, 25 Feb 2025 15:51:35 -0300 Subject: [PATCH 04/15] fix decode_jwt tests --- Makefile | 2 +- poetry.lock | 60 +++++++++++++++-------- pyproject.toml | 6 +-- supabase_auth/_sync/gotrue_client.py | 61 ++++++++++++++++++++--- supabase_auth/errors.py | 1 + supabase_auth/helpers.py | 32 +++++++----- tests/_sync/test_gotrue_admin_api.py | 73 +++++++++++++++------------- 7 files changed, 161 insertions(+), 74 deletions(-) diff --git a/Makefile b/Makefile index 69857fde..5289f26c 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,7 @@ build_run_tests: build_sync run_tests echo "Done" sleep: - sleep 20 + sleep 3 rename_project: rename_package_dir rename_package diff --git a/poetry.lock b/poetry.lock index d3c96613..d6f729f7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1221,18 +1221,19 @@ httpx = ">=0.25.0" [[package]] name = "setuptools" -version = "58.5.3" +version = "72.2.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "setuptools-58.5.3-py3-none-any.whl", hash = "sha256:a481fbc56b33f5d8f6b33dce41482e64c68b668be44ff42922903b03872590bf"}, - {file = "setuptools-58.5.3.tar.gz", hash = "sha256:dae6b934a965c8a59d6d230d3867ec408bb95e73bd538ff77e71fedf1eaca729"}, + {file = "setuptools-72.2.0-py3-none-any.whl", hash = "sha256:f11dd94b7bae3a156a95ec151f24e4637fb4fa19c878e4d191bfb8b2d82728c4"}, + {file = "setuptools-72.2.0.tar.gz", hash = "sha256:80aacbf633704e9c8bfa1d99fa5dd4dc59573efcf9e4042c13d3bcef91ac2ef9"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=8.2)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx", "sphinx-inline-tabs", "sphinxcontrib-towncrier"] -testing = ["flake8-2020", "jaraco.envs", "jaraco.path (>=3.2.0)", "mock", "paver", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.0.1)", "pytest-flake8", "pytest-mypy", "pytest-virtualenv (>=1.2.7)", "pytest-xdist", "sphinx", "virtualenv (>=13.0.0)", "wheel"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "sniffio" @@ -1245,6 +1246,17 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "tokenize-rt" +version = "6.1.0" +description = "A wrapper around the stdlib `tokenize` which roundtrips." +optional = false +python-versions = ">=3.9" +files = [ + {file = "tokenize_rt-6.1.0-py2.py3-none-any.whl", hash = "sha256:d706141cdec4aa5f358945abe36b911b8cbdc844545da99e811250c0cee9b6fc"}, + {file = "tokenize_rt-6.1.0.tar.gz", hash = "sha256:e8ee836616c0877ab7c7b54776d2fefcc3bde714449a206762425ae114b53c86"}, +] + [[package]] name = "tomli" version = "2.0.2" @@ -1300,30 +1312,38 @@ files = [ [[package]] name = "unasync" -version = "0.5.0" +version = "0.6.0" description = "The async transformation code." optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" +python-versions = ">=3.8" files = [ - {file = "unasync-0.5.0-py3-none-any.whl", hash = "sha256:8d4536dae85e87b8751dfcc776f7656fd0baf54bb022a7889440dc1b9dc3becb"}, - {file = "unasync-0.5.0.tar.gz", hash = "sha256:b675d87cf56da68bd065d3b7a67ac71df85591978d84c53083c20d79a7e5096d"}, + {file = "unasync-0.6.0-py3-none-any.whl", hash = "sha256:9cf7aaaea9737e417d8949bf9be55dc25fdb4ef1f4edc21b58f76ff0d2b9d73f"}, + {file = "unasync-0.6.0.tar.gz", hash = "sha256:a9d01ace3e1068b20550ab15b7f9723b15b8bcde728bc1770bcb578374c7ee58"}, ] +[package.dependencies] +setuptools = "*" +tokenize-rt = "*" + [[package]] name = "unasync-cli" -version = "0.0.9" -description = "Command line interface for unasync" +version = "0.0.1" +description = "Command line interface for unasync. Fork of https://github.com/leynier/unasync-cli/" optional = false -python-versions = ">=3.6.14,<4.0.0" -files = [ - {file = "unasync-cli-0.0.9.tar.gz", hash = "sha256:ca9d8c57ebb68911f8f8f68f243c7f6d0bb246ee3fd14743bc51c8317e276554"}, - {file = "unasync_cli-0.0.9-py3-none-any.whl", hash = "sha256:f96c42fb2862efa555ce6d6415a5983ceb162aa0e45be701656d20a955c7c540"}, -] +python-versions = "^3.8.18" +files = [] +develop = false [package.dependencies] -setuptools = ">=58.2.0,<59.0.0" -typer = ">=0.4.0,<0.5.0" -unasync = ">=0.5.0,<0.6.0" +setuptools = "^72.1.0" +typer = "^0.4.0" +unasync = "^0.6.0" + +[package.source] +type = "git" +url = "https://github.com/supabase-community/unasync-cli.git" +reference = "main" +resolved_reference = "6a082ee36d5e8941622b70f6cbcaf8e7a5be339d" [[package]] name = "urllib3" diff --git a/pyproject.toml b/pyproject.toml index d6873904..09c7c90d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,12 +11,12 @@ license = "MIT" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent" + "Operating System :: OS Independent", ] [tool.poetry.dependencies] python = "^3.9" -httpx = {version = ">=0.26,<0.29", extras = ["http2"]} +httpx = { version = ">=0.26,<0.29", extras = ["http2"] } pydantic = ">=1.10,<3" pyjwt = "^2.10.1" @@ -30,7 +30,7 @@ pytest-cov = "^6.0.0" pytest-depends = "^1.0.1" pytest-asyncio = "^0.25.3" Faker = "^36.1.1" -unasync-cli = "^0.0.9" +unasync-cli = { git = "https://github.com/supabase-community/unasync-cli.git", branch = "main" } [tool.poetry.group.dev.dependencies] pygithub = ">=1.57,<3.0" diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index e4e821ef..db6040b0 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -4,7 +4,7 @@ from functools import partial from json import loads from time import time -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 @@ -20,11 +20,12 @@ AuthApiError, AuthImplicitGrantRedirectError, AuthInvalidCredentialsError, + AuthInvalidJwtError, AuthRetryableError, AuthSessionMissingError, ) from ..helpers import ( - decode_jwt_payload, + decode_jwt, generate_pkce_challenge, generate_pkce_verifier, model_dump, @@ -39,6 +40,8 @@ from ..http_clients import SyncClient from ..timer import Timer from ..types import ( + JWK, + JWKS, AuthChangeEvent, AuthenticatorAssuranceLevels, AuthFlowType, @@ -106,6 +109,7 @@ def __init__( verify=verify, proxy=proxy, ) + self._jwks: JWKS = {} self._storage_key = storage_key or STORAGE_KEY self._auto_refresh_token = auto_refresh_token self._persist_session = persist_session @@ -415,7 +419,9 @@ def sign_in_with_oauth( ) return OAuthResponse(provider=provider, url=url_with_qs) - def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse: + def link_identity( + self, credentials: SignInWithOAuthCredentials + ) -> OAuthResponse: provider = credentials.get("provider") options = credentials.get("options", {}) redirect_to = options.get("redirect_to") @@ -698,7 +704,9 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: self._notify_all_subscribers("TOKEN_REFRESHED", session) return AuthResponse(session=session, user=response.user) - def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse: + def refresh_session( + self, refresh_token: Optional[str] = None + ) -> AuthResponse: """ Returns a new session, regardless of expiry status. @@ -1107,7 +1115,9 @@ def _get_url_for_provider( if self._flow_type == "pkce": code_verifier = generate_pkce_verifier() code_challenge = generate_pkce_challenge(code_verifier) - self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier) + self._storage.set_item( + f"{self._storage_key}-code-verifier", code_verifier + ) code_challenge_method = ( "plain" if code_verifier == code_challenge else "s256" ) @@ -1122,7 +1132,8 @@ def _decode_jwt(self, jwt: str) -> DecodedJWTDict: """ Decodes a JWT (without performing any validation). """ - return decode_jwt_payload(jwt) + decoded = decode_jwt(jwt) + return decoded["payload"] def exchange_code_for_session(self, params: CodeExchangeParams): code_verifier = params.get("code_verifier") or self._storage.get_item( @@ -1144,3 +1155,41 @@ def exchange_code_for_session(self, params: CodeExchangeParams): self._save_session(response.session) self._notify_all_subscribers("SIGNED_IN", response.session) return response + + def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK: + # try fetching from the suplied keys. + jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None) + + if jwk: + return jwk + + # try fetching from the cache. + jwk = next((jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None) + if jwk: + return jwk + + # jwk isn't cached in memory so we need to fetch it from the well-known endpoint + response = self._request("GET", ".well-known/jwks.json", xform=parse_jwks) + if response.jwks: + self._jwks = response.jwks + + # find the signing key + jwk = next( + (jwk for jwk in response.jwks["keys"] if jwk["kid"] == kid), None + ) + if not jwk: + raise AuthInvalidJwtError("No matching signing key found in JWKS") + + return jwk + + raise AuthInvalidJwtError("JWT has no valid kid") + + def get_claims(self): + pass + + +def parse_jwks(response: Any) -> JWKS: + if "keys" not in response or len(response.keys) == 0: + raise AuthInvalidJwtError("JWKS is empty") + + return JWKS(keys=response.keys) diff --git a/supabase_auth/errors.py b/supabase_auth/errors.py index 6764aa82..a26f939b 100644 --- a/supabase_auth/errors.py +++ b/supabase_auth/errors.py @@ -232,5 +232,6 @@ def __init__(self, message: str) -> None: self, message, "AuthInvalidJwtError", + 400, "invalid_jwt", ) \ No newline at end of file diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py index 838d16d5..1a7c48fb 100644 --- a/supabase_auth/helpers.py +++ b/supabase_auth/helpers.py @@ -8,7 +8,7 @@ from base64 import urlsafe_b64decode from datetime import datetime from json import loads -from typing import Any, Dict, Literal, Optional, Type, TypeVar, cast +from typing import Any, Dict, Literal, Optional, Type, TypeVar, TypedDict, cast from urllib.parse import urlparse from httpx import HTTPStatusError, Response @@ -30,6 +30,8 @@ AuthResponse, GenerateLinkProperties, GenerateLinkResponse, + JWTHeader, + JWTPayload, LinkIdentityResponse, Session, SSOResponse, @@ -209,24 +211,32 @@ def base64url_to_bytes(base64url: str) -> bytes: return urlsafe_b64decode(base64url_with_padding) -def decode_jwt(token: str) -> Dict[str, Any]: +class DecodedJWT(TypedDict): + header: JWTHeader + payload: JWTPayload + signature: bytes + raw: Dict[str, str] + + +def decode_jwt(token: str) -> DecodedJWT: parts = token.split(".") if len(parts) != 3: raise AuthInvalidJwtError("Invalid JWT structure") # regex check for base64url - if not re.match(BASE64URL_REGEX, parts[1]): - raise AuthInvalidJwtError("JWT not in base64url format") - - return { - "header": loads(str_from_base64url(parts[0])), - "payload": loads(str_from_base64url(parts[1])), - "signature": base64url_to_bytes(parts[2]), - "raw": { + # for part in parts: + # if not re.match(BASE64URL_REGEX, part): + # raise AuthInvalidJwtError("JWT not in base64url format") + + return DecodedJWT( + header=JWTHeader(**loads(str_from_base64url(parts[0]))), + payload=JWTPayload(**loads(str_from_base64url(parts[1]))), + signature=base64url_to_bytes(parts[2]), + raw={ "header": parts[0], "payload": parts[1], }, - } + ) def generate_pkce_verifier(length=64): diff --git a/tests/_sync/test_gotrue_admin_api.py b/tests/_sync/test_gotrue_admin_api.py index c885d4bf..c9769f1f 100644 --- a/tests/_sync/test_gotrue_admin_api.py +++ b/tests/_sync/test_gotrue_admin_api.py @@ -17,7 +17,6 @@ ) from .utils import ( create_new_user_with_email, - mock_access_token, mock_app_metadata, mock_user_credentials, mock_user_metadata, @@ -31,19 +30,19 @@ def test_create_user_should_create_a_new_user(): assert response.email == credentials.get("email") -def test_create_user_with_app_metadata(): - app_metadata = mock_app_metadata() +def test_create_user_with_user_metadata(): + user_metadata = mock_user_metadata() credentials = mock_user_credentials() response = service_role_api_client().create_user( { "email": credentials.get("email"), "password": credentials.get("password"), - "app_metadata": app_metadata, + "user_metadata": user_metadata, } ) assert response.user.email == credentials.get("email") - assert "provider" in response.user.app_metadata - assert "providers" in response.user.app_metadata + assert response.user.user_metadata == user_metadata + assert "profile_image" in response.user.user_metadata def test_create_user_with_user_and_app_metadata(): @@ -157,7 +156,7 @@ def test_modify_confirm_email_using_update_user_by_id(): def test_invalid_credential_sign_in_with_phone(): try: - client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( + response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( { "phone": "+123456789", "password": "strong_pwd", @@ -169,7 +168,7 @@ def test_invalid_credential_sign_in_with_phone(): def test_invalid_credential_sign_in_with_email(): try: - client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( + response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( { "email": "unknown_user@unknowndomain.com", "password": "strong_pwd", @@ -364,39 +363,47 @@ def test_verify_otp_with_invalid_phone_number(): assert e.message == "Invalid phone number format (E.164 required)" -def test_sign_in_with_oauth(): - assert client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth( - { - "provider": "google", - } - ) - +def test_sign_in_with_id_token(): + try: + client_api_auto_confirm_off_signups_enabled_client().sign_in_with_id_token( + { + "provider": "google", + "token": "123456", + } + ) + except AuthApiError as e: + assert e.to_dict() -def test_decode_jwt(): - assert auth_client_with_session()._decode_jwt(mock_access_token()) +def test_sign_in_with_sso(): + with pytest.raises(AuthApiError, match=r"SAML 2.0 is disabled") as exc: + client_api_auto_confirm_off_signups_enabled_client().sign_in_with_sso( + { + "domain": "google", + } + ) + assert exc.value is not None -def test_link_identity_missing_session(): - with pytest.raises(AuthSessionMissingError) as exc: - client_api_auto_confirm_off_signups_enabled_client().link_identity( +def test_sign_in_with_oauth(): + assert ( + client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth( { "provider": "google", } ) - assert exc.value is not None + ) -def test_sign_in_with_id_token(): - try: - client_api_auto_confirm_off_signups_enabled_client().sign_in_with_id_token( +def test_link_identity_missing_session(): + + with pytest.raises(AuthSessionMissingError) as exc: + client_api_auto_confirm_off_signups_enabled_client().link_identity( { "provider": "google", - "token": "123456", } ) - except AuthApiError as e: - assert e.to_dict() + assert exc.value is not None def test_get_item_from_memory_storage(): @@ -518,7 +525,7 @@ def test_get_user_identities(): "password": credentials.get("password"), } ) - assert client.get_user_identities().identities[0].identity_data[ + assert (client.get_user_identities()).identities[0].identity_data[ "email" ] == credentials.get("email") @@ -542,19 +549,19 @@ def test_update_user(): ) -def test_create_user_with_user_metadata(): - user_metadata = mock_user_metadata() +def test_create_user_with_app_metadata(): + app_metadata = mock_app_metadata() credentials = mock_user_credentials() response = service_role_api_client().create_user( { "email": credentials.get("email"), "password": credentials.get("password"), - "user_metadata": user_metadata, + "app_metadata": app_metadata, } ) assert response.user.email == credentials.get("email") - assert response.user.user_metadata == user_metadata - assert "profile_image" in response.user.user_metadata + assert "provider" in response.user.app_metadata + assert "providers" in response.user.app_metadata def test_weak_email_password_error(): From 66fa83a14d64a9650d7b71427c929460a14a6bda Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Tue, 25 Feb 2025 15:56:17 -0300 Subject: [PATCH 05/15] fix tests returning another error message --- infra/db/00-schema.sql | 1 - supabase_auth/_sync/gotrue_client.py | 12 +++------ supabase_auth/errors.py | 3 ++- supabase_auth/helpers.py | 4 +-- supabase_auth/types.py | 2 ++ tests/_async/test_gotrue_admin_api.py | 2 +- tests/_sync/test_gotrue_admin_api.py | 36 ++++++++++++++------------- tests/test_helpers.py | 10 +++++--- 8 files changed, 35 insertions(+), 35 deletions(-) diff --git a/infra/db/00-schema.sql b/infra/db/00-schema.sql index ba844794..229af46f 100644 --- a/infra/db/00-schema.sql +++ b/infra/db/00-schema.sql @@ -83,4 +83,3 @@ GRANT ALL PRIVILEGES ON SCHEMA auth TO postgres; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA auth TO postgres; GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA auth TO postgres; ALTER USER postgres SET search_path = "auth"; - diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index db6040b0..6e1cbed4 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -419,9 +419,7 @@ def sign_in_with_oauth( ) return OAuthResponse(provider=provider, url=url_with_qs) - def link_identity( - self, credentials: SignInWithOAuthCredentials - ) -> OAuthResponse: + def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse: provider = credentials.get("provider") options = credentials.get("options", {}) redirect_to = options.get("redirect_to") @@ -704,9 +702,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: self._notify_all_subscribers("TOKEN_REFRESHED", session) return AuthResponse(session=session, user=response.user) - def refresh_session( - self, refresh_token: Optional[str] = None - ) -> AuthResponse: + def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse: """ Returns a new session, regardless of expiry status. @@ -1115,9 +1111,7 @@ def _get_url_for_provider( if self._flow_type == "pkce": code_verifier = generate_pkce_verifier() code_challenge = generate_pkce_challenge(code_verifier) - self._storage.set_item( - f"{self._storage_key}-code-verifier", code_verifier - ) + self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier) code_challenge_method = ( "plain" if code_verifier == code_challenge else "s256" ) diff --git a/supabase_auth/errors.py b/supabase_auth/errors.py index a26f939b..cc85b87e 100644 --- a/supabase_auth/errors.py +++ b/supabase_auth/errors.py @@ -226,6 +226,7 @@ def to_dict(self) -> AuthApiErrorDict: "reasons": self.reasons, } + class AuthInvalidJwtError(CustomAuthError): def __init__(self, message: str) -> None: CustomAuthError.__init__( @@ -234,4 +235,4 @@ def __init__(self, message: str) -> None: "AuthInvalidJwtError", 400, "invalid_jwt", - ) \ No newline at end of file + ) diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py index 1a7c48fb..57cecf31 100644 --- a/supabase_auth/helpers.py +++ b/supabase_auth/helpers.py @@ -8,12 +8,12 @@ from base64 import urlsafe_b64decode from datetime import datetime from json import loads -from typing import Any, Dict, Literal, Optional, Type, TypeVar, TypedDict, cast +from typing import Any, Dict, Literal, Optional, Type, TypedDict, TypeVar, cast from urllib.parse import urlparse -from httpx import HTTPStatusError, Response import jwt import jwt.algorithms +from httpx import HTTPStatusError, Response from pydantic import BaseModel from .constants import API_VERSION_HEADER_NAME, API_VERSIONS diff --git a/supabase_auth/types.py b/supabase_auth/types.py index dd499aa0..5425173a 100644 --- a/supabase_auth/types.py +++ b/supabase_auth/types.py @@ -816,9 +816,11 @@ class JWK(TypedDict, total=False): alg: Optional[str] kid: Optional[str] + class JWKS(TypedDict): keys: List[JWK] + for model in [ AMREntry, AuthResponse, diff --git a/tests/_async/test_gotrue_admin_api.py b/tests/_async/test_gotrue_admin_api.py index 701f46c8..dc96a145 100644 --- a/tests/_async/test_gotrue_admin_api.py +++ b/tests/_async/test_gotrue_admin_api.py @@ -344,7 +344,7 @@ async def test_verify_otp_with_non_existent_phone_number(): ) assert False except AuthError as e: - assert e.message == "User not found" + assert e.message == "Token has expired or is invalid" async def test_verify_otp_with_invalid_phone_number(): diff --git a/tests/_sync/test_gotrue_admin_api.py b/tests/_sync/test_gotrue_admin_api.py index c9769f1f..dd5fa523 100644 --- a/tests/_sync/test_gotrue_admin_api.py +++ b/tests/_sync/test_gotrue_admin_api.py @@ -156,11 +156,13 @@ def test_modify_confirm_email_using_update_user_by_id(): def test_invalid_credential_sign_in_with_phone(): try: - response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( - { - "phone": "+123456789", - "password": "strong_pwd", - } + response = ( + client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( + { + "phone": "+123456789", + "password": "strong_pwd", + } + ) ) except AuthApiError as e: assert e.to_dict() @@ -168,11 +170,13 @@ def test_invalid_credential_sign_in_with_phone(): def test_invalid_credential_sign_in_with_email(): try: - response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( - { - "email": "unknown_user@unknowndomain.com", - "password": "strong_pwd", - } + response = ( + client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( + { + "email": "unknown_user@unknowndomain.com", + "password": "strong_pwd", + } + ) ) except AuthApiError as e: assert e.to_dict() @@ -344,7 +348,7 @@ def test_verify_otp_with_non_existent_phone_number(): ) assert False except AuthError as e: - assert e.message == "User not found" + assert e.message == "Token has expired or is invalid" def test_verify_otp_with_invalid_phone_number(): @@ -386,12 +390,10 @@ def test_sign_in_with_sso(): def test_sign_in_with_oauth(): - assert ( - client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth( - { - "provider": "google", - } - ) + assert client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth( + { + "provider": "google", + } ) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 4132e9e6..7fb27809 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -6,7 +6,11 @@ from httpx import Headers, Response from supabase_auth.constants import API_VERSION_HEADER_NAME -from supabase_auth.errors import AuthApiError, AuthInvalidJwtError, AuthWeakPasswordError +from supabase_auth.errors import ( + AuthApiError, + AuthInvalidJwtError, + AuthWeakPasswordError, +) from supabase_auth.helpers import ( decode_jwt, generate_pkce_challenge, @@ -114,9 +118,7 @@ def test_get_error_code(): def test_decode_jwt(): assert decode_jwt(mock_access_token()) - with pytest.raises( - AuthInvalidJwtError, match=r"Invalid JWT structure" - ) as exc: + with pytest.raises(AuthInvalidJwtError, match=r"Invalid JWT structure") as exc: decode_jwt("non-valid-jwt") assert exc.value is not None From 4ae36c4110a74589e331b25baae07411df913dcd Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Mon, 10 Mar 2025 18:15:31 +0100 Subject: [PATCH 06/15] wip tests --- poetry.lock | 17 ++++++++++++ pyproject.toml | 1 + supabase_auth/_async/gotrue_client.py | 39 ++++++++++++++++++++++----- supabase_auth/constants.py | 1 + supabase_auth/helpers.py | 3 +-- supabase_auth/types.py | 9 +++++-- tests/_async/test_gotrue.py | 24 +++++++++++++++++ 7 files changed, 84 insertions(+), 10 deletions(-) create mode 100644 tests/_async/test_gotrue.py diff --git a/poetry.lock b/poetry.lock index d6f729f7..1020f882 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1122,6 +1122,23 @@ future-fstrings = "*" networkx = "*" pytest = ">=3" +[[package]] +name = "pytest-mock" +version = "3.14.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, +] + +[package.dependencies] +pytest = ">=6.2.5" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + [[package]] name = "pyyaml" version = "6.0.2" diff --git a/pyproject.toml b/pyproject.toml index 09c7c90d..deaf4c90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ python = "^3.9" httpx = { version = ">=0.26,<0.29", extras = ["http2"] } pydantic = ">=1.10,<3" pyjwt = "^2.10.1" +pytest-mock = "^3.14.0" [tool.poetry.dev-dependencies] pytest = "^8.3.5" diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index 4069991b..a99a1133 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -4,7 +4,7 @@ from functools import partial from json import loads from time import time -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 @@ -28,6 +28,7 @@ decode_jwt, generate_pkce_challenge, generate_pkce_verifier, + get_algorithm, model_dump, model_dump_json, model_validate, @@ -36,6 +37,7 @@ parse_link_identity_response, parse_sso_response, parse_user_response, + validate_exp, ) from ..http_clients import AsyncClient from ..timer import Timer @@ -53,9 +55,12 @@ AuthMFAVerifyResponse, AuthOtpResponse, AuthResponse, + ClaimsResponse, CodeExchangeParams, DecodedJWTDict, IdentitiesResponse, + JWTHeader, + JWTPayload, MFAChallengeAndVerifyParams, MFAChallengeParams, MFAEnrollParams, @@ -1158,13 +1163,13 @@ async def exchange_code_for_session(self, params: CodeExchangeParams): async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK: # try fetching from the suplied keys. - jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None) + jwk = next((jwk for jwk in jwks.keys if jwk["kid"] == kid), None) if jwk: return jwk # try fetching from the cache. - jwk = next((jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None) + jwk = next((jwk for jwk in self._jwks.keys if jwk["kid"] == kid), None) if jwk: return jwk @@ -1184,12 +1189,34 @@ async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK: raise AuthInvalidJwtError("JWT has no valid kid") - async def get_claims(self): - pass + async def get_claims(self, jwt: Optional[str] = None, jwks: Optional[JWKS] = None) -> Optional[ClaimsResponse]: + token = jwt + if not token: + session = await self.get_session() + if not session: + return None + + token = session.access_token + + decoded_jwt = decode_jwt(token) + payload = decoded_jwt["payload"] + header = decoded_jwt["header"] + signature = decoded_jwt["signature"] + + validate_exp(payload["exp"]) + + # if symetric algorithm, fallback to get_user + if 'kid' not in header or header["alg"] == 'HS256': + await self.get_user(token) + return ClaimsResponse(claims=payload, headers=header, signature=signature) + + algorithm = get_algorithm(header["alg"]) + signing_key = await self._fetch_jwks(header["kid"], jwks) + algorithm.prepare_key def parse_jwks(response: Any) -> JWKS: if "keys" not in response or len(response.keys) == 0: raise AuthInvalidJwtError("JWKS is empty") - return JWKS(keys=response.keys) + return JWKS(keys=response["keys"]) diff --git a/supabase_auth/constants.py b/supabase_auth/constants.py index d16c0f31..5b0d1abc 100644 --- a/supabase_auth/constants.py +++ b/supabase_auth/constants.py @@ -21,3 +21,4 @@ "name": "2024-01-01", }, } +BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$" \ No newline at end of file diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py index 57cecf31..e46214c3 100644 --- a/supabase_auth/helpers.py +++ b/supabase_auth/helpers.py @@ -16,7 +16,7 @@ from httpx import HTTPStatusError, Response from pydantic import BaseModel -from .constants import API_VERSION_HEADER_NAME, API_VERSIONS +from .constants import API_VERSION_HEADER_NAME, API_VERSIONS, BASE64URL_REGEX from .errors import ( AuthApiError, AuthError, @@ -40,7 +40,6 @@ ) TBaseModel = TypeVar("TBaseModel", bound=BaseModel) -BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$" def model_validate(model: Type[TBaseModel], contents) -> TBaseModel: diff --git a/supabase_auth/types.py b/supabase_auth/types.py index 5425173a..3d671cde 100644 --- a/supabase_auth/types.py +++ b/supabase_auth/types.py @@ -806,11 +806,11 @@ class RequiredClaims(TypedDict): session_id: str -class JWTPayload(RequiredClaims, TypedDict, total=False): +class JWTPayload(RequiredClaims, total=False): pass -class JWK(TypedDict, total=False): +class JWK(TypedDict): kty: Literal["RSA", "EC", "oct"] key_ops: List[str] alg: Optional[str] @@ -821,6 +821,11 @@ class JWKS(TypedDict): keys: List[JWK] +class ClaimsResponse(TypedDict): + claims: JWTPayload + headers: JWTHeader + signature: bytes + for model in [ AMREntry, AuthResponse, diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py new file mode 100644 index 00000000..6d13addd --- /dev/null +++ b/tests/_async/test_gotrue.py @@ -0,0 +1,24 @@ +import pytest +from pytest_mock import mocker + +from .utils import mock_user_credentials + +from .clients import auth_client + +async def test_get_claims_returns_none_when_session_is_none(): + claims = await auth_client().get_claims() + assert claims is None + +async def test_get_claims_calls_get_user_if_symmetric_jwt(mocker): + client = auth_client() + credentials = mock_user_credentials() + spy = mocker.spy(client, 'get_user') + + response = await client.sign_up(credentials) + user = response.user + assert user is not None + + response = await client.get_claims() + assert response["claims"]["email"] == user.email + spy.assert_called_once() + \ No newline at end of file From 8683823423dcf0c896eabb5483e4e78d6c9dd56b Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Wed, 12 Mar 2025 11:06:46 +0100 Subject: [PATCH 07/15] tests for asymmetric signature --- supabase_auth/_async/gotrue_client.py | 72 +++++++++++++++--------- supabase_auth/_sync/gotrue_client.py | 79 +++++++++++++++++++++------ supabase_auth/constants.py | 2 +- supabase_auth/helpers.py | 11 +--- supabase_auth/types.py | 12 +--- tests/_async/clients.py | 12 ++++ tests/_async/test_gotrue.py | 48 ++++++++++------ tests/_sync/clients.py | 12 ++++ tests/_sync/test_gotrue.py | 41 ++++++++++++++ 9 files changed, 207 insertions(+), 82 deletions(-) create mode 100644 tests/_sync/test_gotrue.py diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index a99a1133..2fa8154c 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -4,10 +4,12 @@ from functools import partial from json import loads from time import time -from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict +from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 +from jwt import get_algorithm_by_name + from ..constants import ( DEFAULT_HEADERS, EXPIRY_MARGIN, @@ -28,7 +30,6 @@ decode_jwt, generate_pkce_challenge, generate_pkce_verifier, - get_algorithm, model_dump, model_dump_json, model_validate, @@ -42,8 +43,6 @@ from ..http_clients import AsyncClient from ..timer import Timer from ..types import ( - JWK, - JWKS, AuthChangeEvent, AuthenticatorAssuranceLevels, AuthFlowType, @@ -59,8 +58,6 @@ CodeExchangeParams, DecodedJWTDict, IdentitiesResponse, - JWTHeader, - JWTPayload, MFAChallengeAndVerifyParams, MFAChallengeParams, MFAEnrollParams, @@ -114,7 +111,7 @@ def __init__( verify=verify, proxy=proxy, ) - self._jwks: JWKS = {} + self._jwks = {"keys": []} self._storage_key = storage_key or STORAGE_KEY self._auto_refresh_token = auto_refresh_token self._persist_session = persist_session @@ -1161,26 +1158,32 @@ async def exchange_code_for_session(self, params: CodeExchangeParams): self._notify_all_subscribers("SIGNED_IN", response.session) return response - async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK: + async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]: + jwk: Dict[str, Any] = {} + # try fetching from the suplied keys. - jwk = next((jwk for jwk in jwks.keys if jwk["kid"] == kid), None) + jwk = next((jwk for jwk in jwks.get("keys", []) if jwk.get("kid") == kid), None) if jwk: return jwk - # try fetching from the cache. - jwk = next((jwk for jwk in self._jwks.keys if jwk["kid"] == kid), None) - if jwk: - return jwk + if self._jwks: + # try fetching from the cache. + jwk = next( + (jwk for jwk in self._jwks.get("keys", []) if jwk.get("kid") == kid), + None, + ) + if jwk: + return jwk # jwk isn't cached in memory so we need to fetch it from the well-known endpoint response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks) - if response.jwks: - self._jwks = response.jwks + if response: + self._jwks = response # find the signing key jwk = next( - (jwk for jwk in response.jwks["keys"] if jwk["kid"] == kid), None + (jwk for jwk in response.get("keys", []) if jwk.get("kid") == kid), None ) if not jwk: raise AuthInvalidJwtError("No matching signing key found in JWKS") @@ -1189,13 +1192,15 @@ async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK: raise AuthInvalidJwtError("JWT has no valid kid") - async def get_claims(self, jwt: Optional[str] = None, jwks: Optional[JWKS] = None) -> Optional[ClaimsResponse]: + async def get_claims( + self, jwt: Optional[str] = None, jwks: Optional[Dict[str, list]] = None + ) -> Optional[ClaimsResponse]: token = jwt if not token: session = await self.get_session() if not session: return None - + token = session.access_token decoded_jwt = decode_jwt(token) @@ -1203,20 +1208,35 @@ async def get_claims(self, jwt: Optional[str] = None, jwks: Optional[JWKS] = Non header = decoded_jwt["header"] signature = decoded_jwt["signature"] + raw_header = decoded_jwt["raw"]["header"] + raw_payload = decoded_jwt["raw"]["payload"] + validate_exp(payload["exp"]) - # if symetric algorithm, fallback to get_user - if 'kid' not in header or header["alg"] == 'HS256': + # if symmetric algorithm, fallback to get_user + if "kid" not in header or header["alg"] == "HS256": await self.get_user(token) return ClaimsResponse(claims=payload, headers=header, signature=signature) - algorithm = get_algorithm(header["alg"]) - signing_key = await self._fetch_jwks(header["kid"], jwks) - algorithm.prepare_key + algorithm = get_algorithm_by_name(header["alg"]) + signing_key = algorithm.from_jwk( + await self._fetch_jwks(header["kid"], jwks or {}) + ) + + # verify the signature + is_valid = algorithm.verify( + msg=f"{raw_header}.{raw_payload}".encode(), key=signing_key, sig=signature + ) + + if not is_valid: + raise AuthInvalidJwtError("Invalid JWT signature") + + # If verification succeeds, decode and return claims + return ClaimsResponse(claims=payload, headers=header, signature=signature) -def parse_jwks(response: Any) -> JWKS: - if "keys" not in response or len(response.keys) == 0: +def parse_jwks(response: Any) -> Dict[str, list]: + if "keys" not in response or len(response["keys"]) == 0: raise AuthInvalidJwtError("JWKS is empty") - return JWKS(keys=response["keys"]) + return response diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index 6e1cbed4..41c2de5d 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -8,6 +8,8 @@ from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 +from jwt import get_algorithm_by_name + from ..constants import ( DEFAULT_HEADERS, EXPIRY_MARGIN, @@ -36,12 +38,11 @@ parse_link_identity_response, parse_sso_response, parse_user_response, + validate_exp, ) from ..http_clients import SyncClient from ..timer import Timer from ..types import ( - JWK, - JWKS, AuthChangeEvent, AuthenticatorAssuranceLevels, AuthFlowType, @@ -53,6 +54,7 @@ AuthMFAVerifyResponse, AuthOtpResponse, AuthResponse, + ClaimsResponse, CodeExchangeParams, DecodedJWTDict, IdentitiesResponse, @@ -109,7 +111,7 @@ def __init__( verify=verify, proxy=proxy, ) - self._jwks: JWKS = {} + self._jwks = {"keys": []} self._storage_key = storage_key or STORAGE_KEY self._auto_refresh_token = auto_refresh_token self._persist_session = persist_session @@ -1150,26 +1152,32 @@ def exchange_code_for_session(self, params: CodeExchangeParams): self._notify_all_subscribers("SIGNED_IN", response.session) return response - def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK: + def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]: + jwk: Dict[str, Any] = {} + # try fetching from the suplied keys. - jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None) + jwk = next((jwk for jwk in jwks.get("keys", []) if jwk.get("kid") == kid), None) if jwk: return jwk - # try fetching from the cache. - jwk = next((jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None) - if jwk: - return jwk + if self._jwks: + # try fetching from the cache. + jwk = next( + (jwk for jwk in self._jwks.get("keys", []) if jwk.get("kid") == kid), + None, + ) + if jwk: + return jwk # jwk isn't cached in memory so we need to fetch it from the well-known endpoint response = self._request("GET", ".well-known/jwks.json", xform=parse_jwks) - if response.jwks: - self._jwks = response.jwks + if response: + self._jwks = response # find the signing key jwk = next( - (jwk for jwk in response.jwks["keys"] if jwk["kid"] == kid), None + (jwk for jwk in response.get("keys", []) if jwk.get("kid") == kid), None ) if not jwk: raise AuthInvalidJwtError("No matching signing key found in JWKS") @@ -1178,12 +1186,49 @@ def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK: raise AuthInvalidJwtError("JWT has no valid kid") - def get_claims(self): - pass + def get_claims( + self, jwt: Optional[str] = None, jwks: Optional[Dict[str, list]] = None + ) -> Optional[ClaimsResponse]: + token = jwt + if not token: + session = self.get_session() + if not session: + return None + + token = session.access_token + + decoded_jwt = decode_jwt(token) + payload = decoded_jwt["payload"] + header = decoded_jwt["header"] + signature = decoded_jwt["signature"] + + raw_header = decoded_jwt["raw"]["header"] + raw_payload = decoded_jwt["raw"]["payload"] + + validate_exp(payload["exp"]) + + # if symmetric algorithm, fallback to get_user + if "kid" not in header or header["alg"] == "HS256": + self.get_user(token) + return ClaimsResponse(claims=payload, headers=header, signature=signature) + + algorithm = get_algorithm_by_name(header["alg"]) + signing_key = algorithm.from_jwk(self._fetch_jwks(header["kid"], jwks or {})) + + # verify the signature + is_valid = algorithm.verify( + msg=f"{raw_header}.{raw_payload}".encode(), key=signing_key, sig=signature + ) + + if not is_valid: + raise AuthInvalidJwtError("Invalid JWT signature") + + # If verification succeeds, decode and return claims + return ClaimsResponse(claims=payload, headers=header, signature=signature) -def parse_jwks(response: Any) -> JWKS: - if "keys" not in response or len(response.keys) == 0: +def parse_jwks(response: Any) -> Dict[str, list]: + if "keys" not in response or len(response["keys"]) == 0: raise AuthInvalidJwtError("JWKS is empty") - return JWKS(keys=response.keys) + return response diff --git a/supabase_auth/constants.py b/supabase_auth/constants.py index 5b0d1abc..671510e5 100644 --- a/supabase_auth/constants.py +++ b/supabase_auth/constants.py @@ -21,4 +21,4 @@ "name": "2024-01-01", }, } -BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$" \ No newline at end of file +BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$" diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py index e46214c3..65c68ad4 100644 --- a/supabase_auth/helpers.py +++ b/supabase_auth/helpers.py @@ -8,11 +8,9 @@ from base64 import urlsafe_b64decode from datetime import datetime from json import loads -from typing import Any, Dict, Literal, Optional, Type, TypedDict, TypeVar, cast +from typing import Any, Dict, Optional, Type, TypedDict, TypeVar, cast from urllib.parse import urlparse -import jwt -import jwt.algorithms from httpx import HTTPStatusError, Response from pydantic import BaseModel @@ -311,10 +309,3 @@ def validate_exp(exp: int) -> None: time_now = datetime.now().timestamp() if exp <= time_now: raise AuthInvalidJwtError("JWT has expired") - - -def get_algorithm(alg: Literal["RS256", "ES256"]) -> jwt.algorithms.Algorithm: - if alg == "RS256": - return jwt.algorithms.RSAAlgorithm - elif alg == "ES256": - return jwt.algorithms.ECAlgorithm diff --git a/supabase_auth/types.py b/supabase_auth/types.py index 3d671cde..4ef3c1ac 100644 --- a/supabase_auth/types.py +++ b/supabase_auth/types.py @@ -810,22 +810,12 @@ class JWTPayload(RequiredClaims, total=False): pass -class JWK(TypedDict): - kty: Literal["RSA", "EC", "oct"] - key_ops: List[str] - alg: Optional[str] - kid: Optional[str] - - -class JWKS(TypedDict): - keys: List[JWK] - - class ClaimsResponse(TypedDict): claims: JWTPayload headers: JWTHeader signature: bytes + for model in [ AMREntry, AuthResponse, diff --git a/tests/_async/clients.py b/tests/_async/clients.py index 3babf2ee..03356714 100644 --- a/tests/_async/clients.py +++ b/tests/_async/clients.py @@ -5,6 +5,7 @@ SIGNUP_ENABLED_AUTO_CONFIRM_OFF_PORT = 9999 SIGNUP_ENABLED_AUTO_CONFIRM_ON_PORT = 9998 SIGNUP_DISABLED_AUTO_CONFIRM_OFF_PORT = 9997 +SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON_PORT = 9996 GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_OFF = ( f"http://localhost:{SIGNUP_ENABLED_AUTO_CONFIRM_OFF_PORT}" @@ -12,6 +13,9 @@ GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON = ( f"http://localhost:{SIGNUP_ENABLED_AUTO_CONFIRM_ON_PORT}" ) +GOTRUE_URL_SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON = ( + f"http://localhost:{SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON_PORT}" +) GOTRUE_URL_SIGNUP_DISABLED_AUTO_CONFIRM_OFF = ( f"http://localhost:{SIGNUP_DISABLED_AUTO_CONFIRM_OFF_PORT}" ) @@ -43,6 +47,14 @@ def auth_client_with_session(): ) +def auth_client_with_asymmetric_session() -> AsyncGoTrueClient: + return AsyncGoTrueClient( + url=GOTRUE_URL_SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON, + auto_refresh_token=False, + persist_session=False, + ) + + def auth_subscription_client(): return AsyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py index 6d13addd..05adf67a 100644 --- a/tests/_async/test_gotrue.py +++ b/tests/_async/test_gotrue.py @@ -1,24 +1,38 @@ -import pytest -from pytest_mock import mocker +import unittest +from .clients import auth_client, auth_client_with_asymmetric_session from .utils import mock_user_credentials -from .clients import auth_client async def test_get_claims_returns_none_when_session_is_none(): - claims = await auth_client().get_claims() - assert claims is None + claims = await auth_client().get_claims() + assert claims is None + async def test_get_claims_calls_get_user_if_symmetric_jwt(mocker): - client = auth_client() - credentials = mock_user_credentials() - spy = mocker.spy(client, 'get_user') - - response = await client.sign_up(credentials) - user = response.user - assert user is not None - - response = await client.get_claims() - assert response["claims"]["email"] == user.email - spy.assert_called_once() - \ No newline at end of file + client = auth_client() + spy = mocker.spy(client, "get_user") + + user = (await client.sign_up(mock_user_credentials())).user + assert user is not None + + claims = (await client.get_claims())["claims"] + assert claims["email"] == user.email + spy.assert_called_once() + + +async def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker): + client = auth_client_with_asymmetric_session() + + user = (await client.sign_up(mock_user_credentials())).user + assert user is not None + + spy = mocker.spy(client, "_request") + + claims = (await client.get_claims())["claims"] + assert claims["email"] == user.email + + spy.assert_called_once() + spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) + + assert len(spy.spy_return.get("keys")) > 0 diff --git a/tests/_sync/clients.py b/tests/_sync/clients.py index 50086c2a..3fee59d1 100644 --- a/tests/_sync/clients.py +++ b/tests/_sync/clients.py @@ -5,6 +5,7 @@ SIGNUP_ENABLED_AUTO_CONFIRM_OFF_PORT = 9999 SIGNUP_ENABLED_AUTO_CONFIRM_ON_PORT = 9998 SIGNUP_DISABLED_AUTO_CONFIRM_OFF_PORT = 9997 +SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON_PORT = 9996 GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_OFF = ( f"http://localhost:{SIGNUP_ENABLED_AUTO_CONFIRM_OFF_PORT}" @@ -12,6 +13,9 @@ GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON = ( f"http://localhost:{SIGNUP_ENABLED_AUTO_CONFIRM_ON_PORT}" ) +GOTRUE_URL_SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON = ( + f"http://localhost:{SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON_PORT}" +) GOTRUE_URL_SIGNUP_DISABLED_AUTO_CONFIRM_OFF = ( f"http://localhost:{SIGNUP_DISABLED_AUTO_CONFIRM_OFF_PORT}" ) @@ -43,6 +47,14 @@ def auth_client_with_session(): ) +def auth_client_with_asymmetric_session() -> SyncGoTrueClient: + return SyncGoTrueClient( + url=GOTRUE_URL_SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON, + auto_refresh_token=False, + persist_session=False, + ) + + def auth_subscription_client(): return SyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py new file mode 100644 index 00000000..c0b576ad --- /dev/null +++ b/tests/_sync/test_gotrue.py @@ -0,0 +1,41 @@ +import unittest +import pytest +from pytest_mock import mocker + +from .utils import mock_user_credentials + +from .clients import auth_client, auth_client_with_asymmetric_session + +def test_get_claims_returns_none_when_session_is_none(): + claims = auth_client().get_claims() + assert claims is None + +def test_get_claims_calls_get_user_if_symmetric_jwt(mocker): + client = auth_client() + spy = mocker.spy(client, 'get_user') + + user = (client.sign_up(mock_user_credentials())).user + assert user is not None + + claims = (client.get_claims())["claims"] + assert claims["email"] == user.email + spy.assert_called_once() + + +def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker): + client = auth_client_with_asymmetric_session() + + user = (client.sign_up(mock_user_credentials())).user + assert user is not None + + spy = mocker.spy(client, "_request") + + claims = (client.get_claims())["claims"] + assert claims["email"] == user.email + + spy.assert_called_once() + spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) + + assert len(spy.spy_return.get("keys")) > 0 + + \ No newline at end of file From 3d92584bb15f57017a798f3c3b6e88e5681bec4f Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Wed, 12 Mar 2025 11:20:23 +0100 Subject: [PATCH 08/15] add typed dicts for convenience --- supabase_auth/_async/gotrue_client.py | 38 +++++++++++--------- supabase_auth/_sync/gotrue_client.py | 52 ++++++++++++++++----------- supabase_auth/types.py | 11 ++++++ tests/_sync/test_gotrue.py | 45 +++++++++++------------ tests/_sync/test_gotrue_admin_api.py | 34 +++++++++--------- 5 files changed, 101 insertions(+), 79 deletions(-) diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index 2fa8154c..452ee612 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -43,6 +43,7 @@ from ..http_clients import AsyncClient from ..timer import Timer from ..types import ( + JWK, AuthChangeEvent, AuthenticatorAssuranceLevels, AuthFlowType, @@ -58,6 +59,7 @@ CodeExchangeParams, DecodedJWTDict, IdentitiesResponse, + JWKSet, MFAChallengeAndVerifyParams, MFAChallengeParams, MFAEnrollParams, @@ -111,7 +113,7 @@ def __init__( verify=verify, proxy=proxy, ) - self._jwks = {"keys": []} + self._jwks: JWKSet = {"keys": []} self._storage_key = storage_key or STORAGE_KEY self._auto_refresh_token = auto_refresh_token self._persist_session = persist_session @@ -1158,11 +1160,11 @@ async def exchange_code_for_session(self, params: CodeExchangeParams): self._notify_all_subscribers("SIGNED_IN", response.session) return response - async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]: - jwk: Dict[str, Any] = {} + async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: + jwk: Optional[JWK] = None # try fetching from the suplied keys. - jwk = next((jwk for jwk in jwks.get("keys", []) if jwk.get("kid") == kid), None) + jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None) if jwk: return jwk @@ -1170,7 +1172,7 @@ async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]: if self._jwks: # try fetching from the cache. jwk = next( - (jwk for jwk in self._jwks.get("keys", []) if jwk.get("kid") == kid), + (jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None, ) if jwk: @@ -1182,9 +1184,7 @@ async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]: self._jwks = response # find the signing key - jwk = next( - (jwk for jwk in response.get("keys", []) if jwk.get("kid") == kid), None - ) + jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None) if not jwk: raise AuthInvalidJwtError("No matching signing key found in JWKS") @@ -1193,7 +1193,7 @@ async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]: raise AuthInvalidJwtError("JWT has no valid kid") async def get_claims( - self, jwt: Optional[str] = None, jwks: Optional[Dict[str, list]] = None + self, jwt: Optional[str] = None, jwks: Optional[JWKSet] = None ) -> Optional[ClaimsResponse]: token = jwt if not token: @@ -1204,12 +1204,16 @@ async def get_claims( token = session.access_token decoded_jwt = decode_jwt(token) - payload = decoded_jwt["payload"] - header = decoded_jwt["header"] - signature = decoded_jwt["signature"] - raw_header = decoded_jwt["raw"]["header"] - raw_payload = decoded_jwt["raw"]["payload"] + payload, header, signature = ( + decoded_jwt["payload"], + decoded_jwt["header"], + decoded_jwt["signature"], + ) + raw_header, raw_payload = ( + decoded_jwt["raw"]["header"], + decoded_jwt["raw"]["payload"], + ) validate_exp(payload["exp"]) @@ -1220,7 +1224,7 @@ async def get_claims( algorithm = get_algorithm_by_name(header["alg"]) signing_key = algorithm.from_jwk( - await self._fetch_jwks(header["kid"], jwks or {}) + await self._fetch_jwks(header["kid"], jwks or {"keys": []}) ) # verify the signature @@ -1235,8 +1239,8 @@ async def get_claims( return ClaimsResponse(claims=payload, headers=header, signature=signature) -def parse_jwks(response: Any) -> Dict[str, list]: +def parse_jwks(response: Any) -> JWKSet: if "keys" not in response or len(response["keys"]) == 0: raise AuthInvalidJwtError("JWKS is empty") - return response + return {"keys": response["keys"]} diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index 41c2de5d..9a4404f8 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -43,6 +43,7 @@ from ..http_clients import SyncClient from ..timer import Timer from ..types import ( + JWK, AuthChangeEvent, AuthenticatorAssuranceLevels, AuthFlowType, @@ -58,6 +59,7 @@ CodeExchangeParams, DecodedJWTDict, IdentitiesResponse, + JWKSet, MFAChallengeAndVerifyParams, MFAChallengeParams, MFAEnrollParams, @@ -111,7 +113,7 @@ def __init__( verify=verify, proxy=proxy, ) - self._jwks = {"keys": []} + self._jwks: JWKSet = {"keys": []} self._storage_key = storage_key or STORAGE_KEY self._auto_refresh_token = auto_refresh_token self._persist_session = persist_session @@ -421,7 +423,9 @@ def sign_in_with_oauth( ) return OAuthResponse(provider=provider, url=url_with_qs) - def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse: + def link_identity( + self, credentials: SignInWithOAuthCredentials + ) -> OAuthResponse: provider = credentials.get("provider") options = credentials.get("options", {}) redirect_to = options.get("redirect_to") @@ -704,7 +708,9 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: self._notify_all_subscribers("TOKEN_REFRESHED", session) return AuthResponse(session=session, user=response.user) - def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse: + def refresh_session( + self, refresh_token: Optional[str] = None + ) -> AuthResponse: """ Returns a new session, regardless of expiry status. @@ -1113,7 +1119,9 @@ def _get_url_for_provider( if self._flow_type == "pkce": code_verifier = generate_pkce_verifier() code_challenge = generate_pkce_challenge(code_verifier) - self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier) + self._storage.set_item( + f"{self._storage_key}-code-verifier", code_verifier + ) code_challenge_method = ( "plain" if code_verifier == code_challenge else "s256" ) @@ -1152,11 +1160,11 @@ def exchange_code_for_session(self, params: CodeExchangeParams): self._notify_all_subscribers("SIGNED_IN", response.session) return response - def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]: - jwk: Dict[str, Any] = {} + def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: + jwk: Optional[JWK] = None # try fetching from the suplied keys. - jwk = next((jwk for jwk in jwks.get("keys", []) if jwk.get("kid") == kid), None) + jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None) if jwk: return jwk @@ -1164,7 +1172,7 @@ def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]: if self._jwks: # try fetching from the cache. jwk = next( - (jwk for jwk in self._jwks.get("keys", []) if jwk.get("kid") == kid), + (jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None, ) if jwk: @@ -1176,9 +1184,7 @@ def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]: self._jwks = response # find the signing key - jwk = next( - (jwk for jwk in response.get("keys", []) if jwk.get("kid") == kid), None - ) + jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None) if not jwk: raise AuthInvalidJwtError("No matching signing key found in JWKS") @@ -1187,7 +1193,7 @@ def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]: raise AuthInvalidJwtError("JWT has no valid kid") def get_claims( - self, jwt: Optional[str] = None, jwks: Optional[Dict[str, list]] = None + self, jwt: Optional[str] = None, jwks: Optional[JWKSet] = None ) -> Optional[ClaimsResponse]: token = jwt if not token: @@ -1198,12 +1204,16 @@ def get_claims( token = session.access_token decoded_jwt = decode_jwt(token) - payload = decoded_jwt["payload"] - header = decoded_jwt["header"] - signature = decoded_jwt["signature"] - raw_header = decoded_jwt["raw"]["header"] - raw_payload = decoded_jwt["raw"]["payload"] + payload, header, signature = ( + decoded_jwt["payload"], + decoded_jwt["header"], + decoded_jwt["signature"], + ) + raw_header, raw_payload = ( + decoded_jwt["raw"]["header"], + decoded_jwt["raw"]["payload"], + ) validate_exp(payload["exp"]) @@ -1213,7 +1223,9 @@ def get_claims( return ClaimsResponse(claims=payload, headers=header, signature=signature) algorithm = get_algorithm_by_name(header["alg"]) - signing_key = algorithm.from_jwk(self._fetch_jwks(header["kid"], jwks or {})) + signing_key = algorithm.from_jwk( + self._fetch_jwks(header["kid"], jwks or {"keys": []}) + ) # verify the signature is_valid = algorithm.verify( @@ -1227,8 +1239,8 @@ def get_claims( return ClaimsResponse(claims=payload, headers=header, signature=signature) -def parse_jwks(response: Any) -> Dict[str, list]: +def parse_jwks(response: Any) -> JWKSet: if "keys" not in response or len(response["keys"]) == 0: raise AuthInvalidJwtError("JWKS is empty") - return response + return {"keys": response["keys"]} diff --git a/supabase_auth/types.py b/supabase_auth/types.py index 4ef3c1ac..991a27b4 100644 --- a/supabase_auth/types.py +++ b/supabase_auth/types.py @@ -816,6 +816,17 @@ class ClaimsResponse(TypedDict): signature: bytes +class JWK(TypedDict, total=False): + kty: Literal["RSA", "EC", "oct"] + key_ops: List[str] + alg: Optional[str] + kid: Optional[str] + + +class JWKSet(TypedDict): + keys: List[JWK] + + for model in [ AMREntry, AuthResponse, diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py index c0b576ad..0f1082b4 100644 --- a/tests/_sync/test_gotrue.py +++ b/tests/_sync/test_gotrue.py @@ -1,41 +1,38 @@ import unittest -import pytest -from pytest_mock import mocker +from .clients import auth_client, auth_client_with_asymmetric_session from .utils import mock_user_credentials -from .clients import auth_client, auth_client_with_asymmetric_session def test_get_claims_returns_none_when_session_is_none(): - claims = auth_client().get_claims() - assert claims is None + claims = auth_client().get_claims() + assert claims is None + def test_get_claims_calls_get_user_if_symmetric_jwt(mocker): - client = auth_client() - spy = mocker.spy(client, 'get_user') + client = auth_client() + spy = mocker.spy(client, "get_user") - user = (client.sign_up(mock_user_credentials())).user - assert user is not None + user = (client.sign_up(mock_user_credentials())).user + assert user is not None - claims = (client.get_claims())["claims"] - assert claims["email"] == user.email - spy.assert_called_once() - + claims = (client.get_claims())["claims"] + assert claims["email"] == user.email + spy.assert_called_once() -def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker): - client = auth_client_with_asymmetric_session() - user = (client.sign_up(mock_user_credentials())).user - assert user is not None +def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker): + client = auth_client_with_asymmetric_session() - spy = mocker.spy(client, "_request") + user = (client.sign_up(mock_user_credentials())).user + assert user is not None - claims = (client.get_claims())["claims"] - assert claims["email"] == user.email + spy = mocker.spy(client, "_request") - spy.assert_called_once() - spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) + claims = (client.get_claims())["claims"] + assert claims["email"] == user.email - assert len(spy.spy_return.get("keys")) > 0 + spy.assert_called_once() + spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) - \ No newline at end of file + assert len(spy.spy_return.get("keys")) > 0 diff --git a/tests/_sync/test_gotrue_admin_api.py b/tests/_sync/test_gotrue_admin_api.py index dd5fa523..2df5352a 100644 --- a/tests/_sync/test_gotrue_admin_api.py +++ b/tests/_sync/test_gotrue_admin_api.py @@ -156,13 +156,11 @@ def test_modify_confirm_email_using_update_user_by_id(): def test_invalid_credential_sign_in_with_phone(): try: - response = ( - client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( - { - "phone": "+123456789", - "password": "strong_pwd", - } - ) + response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( + { + "phone": "+123456789", + "password": "strong_pwd", + } ) except AuthApiError as e: assert e.to_dict() @@ -170,13 +168,11 @@ def test_invalid_credential_sign_in_with_phone(): def test_invalid_credential_sign_in_with_email(): try: - response = ( - client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( - { - "email": "unknown_user@unknowndomain.com", - "password": "strong_pwd", - } - ) + response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( + { + "email": "unknown_user@unknowndomain.com", + "password": "strong_pwd", + } ) except AuthApiError as e: assert e.to_dict() @@ -390,10 +386,12 @@ def test_sign_in_with_sso(): def test_sign_in_with_oauth(): - assert client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth( - { - "provider": "google", - } + assert ( + client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth( + { + "provider": "google", + } + ) ) From 46c4c373a1b608ea1f4801c25f812fa8588a11c4 Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Wed, 12 Mar 2025 11:22:01 +0100 Subject: [PATCH 09/15] poetry lock --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 1020f882..8d1c337b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1481,4 +1481,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "b43d599434e02d17b2037adbe56a99a89a948127d180acdd27139e748047c798" +content-hash = "f6e106fec6347fbb9991f6587f507e8ff984db84afd392cd5ee5d5ab682f833b" From c001f79b22c9d4c60862cf8d6397aaeee850dda6 Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Wed, 12 Mar 2025 11:55:00 +0100 Subject: [PATCH 10/15] fix code format --- supabase_auth/_sync/gotrue_client.py | 12 +++------- tests/_sync/test_gotrue_admin_api.py | 34 +++++++++++++++------------- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index 9a4404f8..e0535f08 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -423,9 +423,7 @@ def sign_in_with_oauth( ) return OAuthResponse(provider=provider, url=url_with_qs) - def link_identity( - self, credentials: SignInWithOAuthCredentials - ) -> OAuthResponse: + def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse: provider = credentials.get("provider") options = credentials.get("options", {}) redirect_to = options.get("redirect_to") @@ -708,9 +706,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: self._notify_all_subscribers("TOKEN_REFRESHED", session) return AuthResponse(session=session, user=response.user) - def refresh_session( - self, refresh_token: Optional[str] = None - ) -> AuthResponse: + def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse: """ Returns a new session, regardless of expiry status. @@ -1119,9 +1115,7 @@ def _get_url_for_provider( if self._flow_type == "pkce": code_verifier = generate_pkce_verifier() code_challenge = generate_pkce_challenge(code_verifier) - self._storage.set_item( - f"{self._storage_key}-code-verifier", code_verifier - ) + self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier) code_challenge_method = ( "plain" if code_verifier == code_challenge else "s256" ) diff --git a/tests/_sync/test_gotrue_admin_api.py b/tests/_sync/test_gotrue_admin_api.py index 2df5352a..dd5fa523 100644 --- a/tests/_sync/test_gotrue_admin_api.py +++ b/tests/_sync/test_gotrue_admin_api.py @@ -156,11 +156,13 @@ def test_modify_confirm_email_using_update_user_by_id(): def test_invalid_credential_sign_in_with_phone(): try: - response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( - { - "phone": "+123456789", - "password": "strong_pwd", - } + response = ( + client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( + { + "phone": "+123456789", + "password": "strong_pwd", + } + ) ) except AuthApiError as e: assert e.to_dict() @@ -168,11 +170,13 @@ def test_invalid_credential_sign_in_with_phone(): def test_invalid_credential_sign_in_with_email(): try: - response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( - { - "email": "unknown_user@unknowndomain.com", - "password": "strong_pwd", - } + response = ( + client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password( + { + "email": "unknown_user@unknowndomain.com", + "password": "strong_pwd", + } + ) ) except AuthApiError as e: assert e.to_dict() @@ -386,12 +390,10 @@ def test_sign_in_with_sso(): def test_sign_in_with_oauth(): - assert ( - client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth( - { - "provider": "google", - } - ) + assert client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth( + { + "provider": "google", + } ) From 1d72f0740bb4cbe79b143357b948b8a71746208b Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Wed, 12 Mar 2025 14:25:42 +0100 Subject: [PATCH 11/15] improve test by asserting cache keys --- supabase_auth/_async/gotrue_client.py | 10 ++-------- supabase_auth/_sync/gotrue_client.py | 10 ++-------- supabase_auth/helpers.py | 8 ++++++++ tests/_async/test_gotrue.py | 5 ++++- tests/_sync/test_gotrue.py | 5 ++++- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index 452ee612..2669817d 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -4,7 +4,7 @@ from functools import partial from json import loads from time import time -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 @@ -35,6 +35,7 @@ model_validate, parse_auth_otp_response, parse_auth_response, + parse_jwks, parse_link_identity_response, parse_sso_response, parse_user_response, @@ -1237,10 +1238,3 @@ async def get_claims( # If verification succeeds, decode and return claims return ClaimsResponse(claims=payload, headers=header, signature=signature) - - -def parse_jwks(response: Any) -> JWKSet: - if "keys" not in response or len(response["keys"]) == 0: - raise AuthInvalidJwtError("JWKS is empty") - - return {"keys": response["keys"]} diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index e0535f08..c161cc22 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -4,7 +4,7 @@ from functools import partial from json import loads from time import time -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 @@ -35,6 +35,7 @@ model_validate, parse_auth_otp_response, parse_auth_response, + parse_jwks, parse_link_identity_response, parse_sso_response, parse_user_response, @@ -1231,10 +1232,3 @@ def get_claims( # If verification succeeds, decode and return claims return ClaimsResponse(claims=payload, headers=header, signature=signature) - - -def parse_jwks(response: Any) -> JWKSet: - if "keys" not in response or len(response["keys"]) == 0: - raise AuthInvalidJwtError("JWKS is empty") - - return {"keys": response["keys"]} diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py index 65c68ad4..9aebab0d 100644 --- a/supabase_auth/helpers.py +++ b/supabase_auth/helpers.py @@ -28,6 +28,7 @@ AuthResponse, GenerateLinkProperties, GenerateLinkResponse, + JWKSet, JWTHeader, JWTPayload, LinkIdentityResponse, @@ -119,6 +120,13 @@ def parse_sso_response(data: Any) -> SSOResponse: return model_validate(SSOResponse, data) +def parse_jwks(response: Any) -> JWKSet: + if "keys" not in response or len(response["keys"]) == 0: + raise AuthInvalidJwtError("JWKS is empty") + + return {"keys": response["keys"]} + + def get_error_message(error: Any) -> str: props = ["msg", "message", "error_description", "error"] filter = lambda prop: ( diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py index 05adf67a..4c39c29e 100644 --- a/tests/_async/test_gotrue.py +++ b/tests/_async/test_gotrue.py @@ -35,4 +35,7 @@ async def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker): spy.assert_called_once() spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) - assert len(spy.spy_return.get("keys")) > 0 + expected_keyid = "638c54b8-28c2-4b12-9598-ba12ef610a29" + + assert len(client._jwks["keys"]) == 1 + assert client._jwks["keys"][0]["kid"] == expected_keyid diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py index 0f1082b4..0ccfd032 100644 --- a/tests/_sync/test_gotrue.py +++ b/tests/_sync/test_gotrue.py @@ -35,4 +35,7 @@ def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker): spy.assert_called_once() spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) - assert len(spy.spy_return.get("keys")) > 0 + expected_keyid = "638c54b8-28c2-4b12-9598-ba12ef610a29" + + assert len(client._jwks["keys"]) == 1 + assert client._jwks["keys"][0]["kid"] == expected_keyid From c7a19033574ca93ad77ce176b1d1d5cbd1d0bbb0 Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Fri, 14 Mar 2025 10:57:21 +0100 Subject: [PATCH 12/15] add jwks cache ttl --- supabase_auth/_async/gotrue_client.py | 9 ++++++++- supabase_auth/_sync/gotrue_client.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index 2669817d..aca2d468 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -114,7 +114,11 @@ def __init__( verify=verify, proxy=proxy, ) + self._jwks: JWKSet = {"keys": []} + self._jwks_ttl: float = 600 # 10 minutes + self._jwks_cached_at: Optional[float] = None + self._storage_key = storage_key or STORAGE_KEY self._auto_refresh_token = auto_refresh_token self._persist_session = persist_session @@ -1170,7 +1174,9 @@ async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: if jwk: return jwk - if self._jwks: + if self._jwks and ( + self._jwks_cached_at and self._jwks_cached_at + self._jwks_ttl < time() + ): # try fetching from the cache. jwk = next( (jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), @@ -1183,6 +1189,7 @@ async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks) if response: self._jwks = response + self._jwks_cached_at = time() # find the signing key jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None) diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index c161cc22..0817198c 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -114,7 +114,11 @@ def __init__( verify=verify, proxy=proxy, ) + self._jwks: JWKSet = {"keys": []} + self._jwks_ttl: float = 600 # 10 minutes + self._jwks_cached_at: Optional[float] = None + self._storage_key = storage_key or STORAGE_KEY self._auto_refresh_token = auto_refresh_token self._persist_session = persist_session @@ -1164,7 +1168,9 @@ def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: if jwk: return jwk - if self._jwks: + if self._jwks and ( + self._jwks_cached_at and self._jwks_cached_at + self._jwks_ttl < time() + ): # try fetching from the cache. jwk = next( (jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), @@ -1177,6 +1183,7 @@ def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: response = self._request("GET", ".well-known/jwks.json", xform=parse_jwks) if response: self._jwks = response + self._jwks_cached_at = time() # find the signing key jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None) From 2a4bad56b6d0c32fe4f0d681a2c4a81ebceb4bb9 Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Fri, 14 Mar 2025 11:22:29 +0100 Subject: [PATCH 13/15] jwks cache ttl tests --- supabase_auth/_async/gotrue_client.py | 16 +++++++------- supabase_auth/_sync/gotrue_client.py | 16 +++++++------- tests/_async/test_gotrue.py | 32 +++++++++++++++++++++++++++ tests/_sync/test_gotrue.py | 32 +++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 16 deletions(-) diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index aca2d468..14baf2a7 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -1,9 +1,9 @@ from __future__ import annotations +import time from contextlib import suppress from functools import partial from json import loads -from time import time from typing import Callable, Dict, List, Optional, Tuple from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 @@ -621,7 +621,7 @@ async def get_session(self) -> Optional[Session]: current_session = self._in_memory_session if not current_session: return None - time_now = round(time()) + time_now = round(time.time()) has_expired = ( current_session.expires_at <= time_now + EXPIRY_MARGIN if current_session.expires_at @@ -682,7 +682,7 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon The current session that minimally contains an access token, refresh token and a user. """ - time_now = round(time()) + time_now = round(time.time()) expires_at = time_now has_expired = True session: Optional[Session] = None @@ -959,7 +959,7 @@ async def _get_session_from_url( token_type = self._get_param(params, "token_type") if not token_type: raise AuthImplicitGrantRedirectError("No token_type detected.") - time_now = round(time()) + time_now = round(time.time()) expires_at = time_now + int(expires_in) user = await self.get_user(access_token) session = Session( @@ -982,7 +982,7 @@ async def _recover_and_refresh(self) -> None: if raw_session: await self._remove_session() return - time_now = round(time()) + time_now = round(time.time()) expires_at = current_session.expires_at if expires_at and expires_at < time_now + EXPIRY_MARGIN: refresh_token = current_session.refresh_token @@ -1034,7 +1034,7 @@ async def _save_session(self, session: Session) -> None: self._in_memory_session = session expire_at = session.expires_at if expire_at: - time_now = round(time()) + time_now = round(time.time()) expire_in = expire_at - time_now refresh_duration_before_expires = ( EXPIRY_MARGIN if expire_in > EXPIRY_MARGIN else 0.5 @@ -1175,7 +1175,7 @@ async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: return jwk if self._jwks and ( - self._jwks_cached_at and self._jwks_cached_at + self._jwks_ttl < time() + self._jwks_cached_at and time.time() - self._jwks_cached_at < self._jwks_ttl ): # try fetching from the cache. jwk = next( @@ -1189,7 +1189,7 @@ async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks) if response: self._jwks = response - self._jwks_cached_at = time() + self._jwks_cached_at = time.time() # find the signing key jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None) diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index 0817198c..a5ce08b1 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -1,9 +1,9 @@ from __future__ import annotations +import time from contextlib import suppress from functools import partial from json import loads -from time import time from typing import Callable, Dict, List, Optional, Tuple from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 @@ -619,7 +619,7 @@ def get_session(self) -> Optional[Session]: current_session = self._in_memory_session if not current_session: return None - time_now = round(time()) + time_now = round(time.time()) has_expired = ( current_session.expires_at <= time_now + EXPIRY_MARGIN if current_session.expires_at @@ -680,7 +680,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: The current session that minimally contains an access token, refresh token and a user. """ - time_now = round(time()) + time_now = round(time.time()) expires_at = time_now has_expired = True session: Optional[Session] = None @@ -955,7 +955,7 @@ def _get_session_from_url( token_type = self._get_param(params, "token_type") if not token_type: raise AuthImplicitGrantRedirectError("No token_type detected.") - time_now = round(time()) + time_now = round(time.time()) expires_at = time_now + int(expires_in) user = self.get_user(access_token) session = Session( @@ -978,7 +978,7 @@ def _recover_and_refresh(self) -> None: if raw_session: self._remove_session() return - time_now = round(time()) + time_now = round(time.time()) expires_at = current_session.expires_at if expires_at and expires_at < time_now + EXPIRY_MARGIN: refresh_token = current_session.refresh_token @@ -1030,7 +1030,7 @@ def _save_session(self, session: Session) -> None: self._in_memory_session = session expire_at = session.expires_at if expire_at: - time_now = round(time()) + time_now = round(time.time()) expire_in = expire_at - time_now refresh_duration_before_expires = ( EXPIRY_MARGIN if expire_in > EXPIRY_MARGIN else 0.5 @@ -1169,7 +1169,7 @@ def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: return jwk if self._jwks and ( - self._jwks_cached_at and self._jwks_cached_at + self._jwks_ttl < time() + self._jwks_cached_at and time.time() - self._jwks_cached_at < self._jwks_ttl ): # try fetching from the cache. jwk = next( @@ -1183,7 +1183,7 @@ def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: response = self._request("GET", ".well-known/jwks.json", xform=parse_jwks) if response: self._jwks = response - self._jwks_cached_at = time() + self._jwks_cached_at = time.time() # find the signing key jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None) diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py index 4c39c29e..53f7a487 100644 --- a/tests/_async/test_gotrue.py +++ b/tests/_async/test_gotrue.py @@ -1,3 +1,4 @@ +import time import unittest from .clients import auth_client, auth_client_with_asymmetric_session @@ -39,3 +40,34 @@ async def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker): assert len(client._jwks["keys"]) == 1 assert client._jwks["keys"][0]["kid"] == expected_keyid + + +async def test_jwks_ttl_cache_behavior(mocker): + client = auth_client_with_asymmetric_session() + + spy = mocker.spy(client, "_request") + + # First call should fetch JWKS from endpoint + user = (await client.sign_up(mock_user_credentials())).user + assert user is not None + + await client.get_claims() + spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) + first_call_count = spy.call_count + + # Second call within TTL should use cache + await client.get_claims() + assert spy.call_count == first_call_count # No additional JWKS request + + # Mock time to be after TTL expiry + original_time = time.time + try: + mock_time = mocker.patch("time.time") + mock_time.return_value = original_time() + 601 # TTL is 600 seconds + + # Call after TTL expiry should fetch fresh JWKS + await client.get_claims() + assert spy.call_count == first_call_count + 1 # One more JWKS request + finally: + # Restore original time function + mocker.patch("time.time", original_time) diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py index 0ccfd032..0fb77a6e 100644 --- a/tests/_sync/test_gotrue.py +++ b/tests/_sync/test_gotrue.py @@ -1,3 +1,4 @@ +import time import unittest from .clients import auth_client, auth_client_with_asymmetric_session @@ -39,3 +40,34 @@ def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker): assert len(client._jwks["keys"]) == 1 assert client._jwks["keys"][0]["kid"] == expected_keyid + + +def test_jwks_ttl_cache_behavior(mocker): + client = auth_client_with_asymmetric_session() + + spy = mocker.spy(client, "_request") + + # First call should fetch JWKS from endpoint + user = (client.sign_up(mock_user_credentials())).user + assert user is not None + + client.get_claims() + spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) + first_call_count = spy.call_count + + # Second call within TTL should use cache + client.get_claims() + assert spy.call_count == first_call_count # No additional JWKS request + + # Mock time to be after TTL expiry + original_time = time.time + try: + mock_time = mocker.patch("time.time") + mock_time.return_value = original_time() + 601 # TTL is 600 seconds + + # Call after TTL expiry should fetch fresh JWKS + client.get_claims() + assert spy.call_count == first_call_count + 1 # One more JWKS request + finally: + # Restore original time function + mocker.patch("time.time", original_time) From 6a90f66a526a53029c566743fe3663bb4ffcd6e8 Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Tue, 18 Mar 2025 15:01:28 -0300 Subject: [PATCH 14/15] fix base64 token validation and add tests for decode_jwt --- supabase_auth/_async/gotrue_client.py | 12 +-- supabase_auth/_sync/gotrue_client.py | 12 +-- supabase_auth/helpers.py | 6 +- tests/_async/test_gotrue.py | 120 +++++++++++++++++++++++++- tests/_sync/test_gotrue.py | 120 +++++++++++++++++++++++++- 5 files changed, 245 insertions(+), 25 deletions(-) diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index 14baf2a7..6c450839 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -58,7 +58,6 @@ AuthResponse, ClaimsResponse, CodeExchangeParams, - DecodedJWTDict, IdentitiesResponse, JWKSet, MFAChallengeAndVerifyParams, @@ -687,7 +686,7 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon has_expired = True session: Optional[Session] = None if access_token and access_token.split(".")[1]: - payload = self._decode_jwt(access_token) + payload = decode_jwt(access_token)["payload"] exp = payload.get("exp") if exp: expires_at = int(exp) @@ -899,7 +898,7 @@ async def _get_authenticator_assurance_level( next_level=None, current_authentication_methods=[], ) - payload = self._decode_jwt(session.access_token) + payload = decode_jwt(session.access_token)["payload"] current_level: Optional[AuthenticatorAssuranceLevels] = None if payload.get("aal"): current_level = payload.get("aal") @@ -1137,13 +1136,6 @@ async def _get_url_for_provider( query = urlencode(params) return f"{url}?{query}", params - def _decode_jwt(self, jwt: str) -> DecodedJWTDict: - """ - Decodes a JWT (without performing any validation). - """ - decoded = decode_jwt(jwt) - return decoded["payload"] - async def exchange_code_for_session(self, params: CodeExchangeParams): code_verifier = params.get("code_verifier") or await self._storage.get_item( f"{self._storage_key}-code-verifier" diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index a5ce08b1..fabae947 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -58,7 +58,6 @@ AuthResponse, ClaimsResponse, CodeExchangeParams, - DecodedJWTDict, IdentitiesResponse, JWKSet, MFAChallengeAndVerifyParams, @@ -685,7 +684,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: has_expired = True session: Optional[Session] = None if access_token and access_token.split(".")[1]: - payload = self._decode_jwt(access_token) + payload = decode_jwt(access_token)["payload"] exp = payload.get("exp") if exp: expires_at = int(exp) @@ -895,7 +894,7 @@ def _get_authenticator_assurance_level( next_level=None, current_authentication_methods=[], ) - payload = self._decode_jwt(session.access_token) + payload = decode_jwt(session.access_token)["payload"] current_level: Optional[AuthenticatorAssuranceLevels] = None if payload.get("aal"): current_level = payload.get("aal") @@ -1131,13 +1130,6 @@ def _get_url_for_provider( query = urlencode(params) return f"{url}?{query}", params - def _decode_jwt(self, jwt: str) -> DecodedJWTDict: - """ - Decodes a JWT (without performing any validation). - """ - decoded = decode_jwt(jwt) - return decoded["payload"] - def exchange_code_for_session(self, params: CodeExchangeParams): code_verifier = params.get("code_verifier") or self._storage.get_item( f"{self._storage_key}-code-verifier" diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py index 9aebab0d..9407e630 100644 --- a/supabase_auth/helpers.py +++ b/supabase_auth/helpers.py @@ -229,9 +229,9 @@ def decode_jwt(token: str) -> DecodedJWT: raise AuthInvalidJwtError("Invalid JWT structure") # regex check for base64url - # for part in parts: - # if not re.match(BASE64URL_REGEX, part): - # raise AuthInvalidJwtError("JWT not in base64url format") + for part in parts: + if not re.match(BASE64URL_REGEX, part, re.IGNORECASE): + raise AuthInvalidJwtError("JWT not in base64url format") return DecodedJWT( header=JWTHeader(**loads(str_from_base64url(parts[0]))), diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py index 53f7a487..2a14a562 100644 --- a/tests/_async/test_gotrue.py +++ b/tests/_async/test_gotrue.py @@ -1,7 +1,13 @@ import time import unittest -from .clients import auth_client, auth_client_with_asymmetric_session +import pytest +from jwt import encode + +from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError +from supabase_auth.helpers import decode_jwt + +from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session from .utils import mock_user_credentials @@ -71,3 +77,115 @@ async def test_jwks_ttl_cache_behavior(mocker): finally: # Restore original time function mocker.patch("time.time", original_time) + + +async def test_set_session_with_valid_tokens(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get valid tokens + signup_response = await client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Get the tokens from the signup response + access_token = signup_response.session.access_token + refresh_token = signup_response.session.refresh_token + + # Clear the session + await client._remove_session() + + # Set the session with the tokens + response = await client.set_session(access_token, refresh_token) + + # Verify the response + assert response.session is not None + assert response.session.access_token == access_token + assert response.session.refresh_token == refresh_token + assert response.user is not None + assert response.user.email == credentials.get("email") + + +async def test_set_session_with_expired_token(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get valid tokens + signup_response = await client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Get the tokens from the signup response + access_token = signup_response.session.access_token + refresh_token = signup_response.session.refresh_token + + # Clear the session + await client._remove_session() + + # Create an expired token by modifying the JWT + expired_token = access_token.split(".") + payload = decode_jwt(access_token)["payload"] + payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago + expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[ + 1 + ] + expired_access_token = ".".join(expired_token) + + # Set the session with the expired token + response = await client.set_session(expired_access_token, refresh_token) + + # Verify the response has a new access token (refreshed) + assert response.session is not None + assert response.session.access_token != expired_access_token + assert response.session.refresh_token != refresh_token + assert response.user is not None + assert response.user.email == credentials.get("email") + + +async def test_set_session_without_refresh_token(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get valid tokens + signup_response = await client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Get the access token from the signup response + access_token = signup_response.session.access_token + + # Clear the session + await client._remove_session() + + # Create an expired token + expired_token = access_token.split(".") + payload = decode_jwt(access_token)["payload"] + payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago + expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[ + 1 + ] + expired_access_token = ".".join(expired_token) + + # Try to set the session with an expired token but no refresh token + with pytest.raises(AuthSessionMissingError): + await client.set_session(expired_access_token, "") + + +async def test_set_session_with_invalid_token(): + client = auth_client() + + # Try to set the session with invalid tokens + with pytest.raises(AuthInvalidJwtError): + await client.set_session("invalid.token.here", "invalid_refresh_token") diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py index 0fb77a6e..b68fc9e2 100644 --- a/tests/_sync/test_gotrue.py +++ b/tests/_sync/test_gotrue.py @@ -1,7 +1,13 @@ import time import unittest -from .clients import auth_client, auth_client_with_asymmetric_session +import pytest +from jwt import encode + +from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError +from supabase_auth.helpers import decode_jwt + +from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session from .utils import mock_user_credentials @@ -71,3 +77,115 @@ def test_jwks_ttl_cache_behavior(mocker): finally: # Restore original time function mocker.patch("time.time", original_time) + + +def test_set_session_with_valid_tokens(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get valid tokens + signup_response = client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Get the tokens from the signup response + access_token = signup_response.session.access_token + refresh_token = signup_response.session.refresh_token + + # Clear the session + client._remove_session() + + # Set the session with the tokens + response = client.set_session(access_token, refresh_token) + + # Verify the response + assert response.session is not None + assert response.session.access_token == access_token + assert response.session.refresh_token == refresh_token + assert response.user is not None + assert response.user.email == credentials.get("email") + + +def test_set_session_with_expired_token(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get valid tokens + signup_response = client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Get the tokens from the signup response + access_token = signup_response.session.access_token + refresh_token = signup_response.session.refresh_token + + # Clear the session + client._remove_session() + + # Create an expired token by modifying the JWT + expired_token = access_token.split(".") + payload = decode_jwt(access_token)["payload"] + payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago + expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[ + 1 + ] + expired_access_token = ".".join(expired_token) + + # Set the session with the expired token + response = client.set_session(expired_access_token, refresh_token) + + # Verify the response has a new access token (refreshed) + assert response.session is not None + assert response.session.access_token != expired_access_token + assert response.session.refresh_token != refresh_token + assert response.user is not None + assert response.user.email == credentials.get("email") + + +def test_set_session_without_refresh_token(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get valid tokens + signup_response = client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Get the access token from the signup response + access_token = signup_response.session.access_token + + # Clear the session + client._remove_session() + + # Create an expired token + expired_token = access_token.split(".") + payload = decode_jwt(access_token)["payload"] + payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago + expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[ + 1 + ] + expired_access_token = ".".join(expired_token) + + # Try to set the session with an expired token but no refresh token + with pytest.raises(AuthSessionMissingError): + client.set_session(expired_access_token, "") + + +def test_set_session_with_invalid_token(): + client = auth_client() + + # Try to set the session with invalid tokens + with pytest.raises(AuthInvalidJwtError): + client.set_session("invalid.token.here", "invalid_refresh_token") From 66ffe5d432aca0381ac3d3825e9cd1da4830d66d Mon Sep 17 00:00:00 2001 From: Guilherme Souza <guilherme@supabase.io> Date: Tue, 18 Mar 2025 15:01:45 -0300 Subject: [PATCH 15/15] clean up _refresh_token_timer --- supabase_auth/_async/gotrue_client.py | 13 +++++++++++++ supabase_auth/_sync/gotrue_client.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index 6c450839..02befb90 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -1237,3 +1237,16 @@ async def get_claims( # If verification succeeds, decode and return claims return ClaimsResponse(claims=payload, headers=header, signature=signature) + + def __del__(self) -> None: + """Clean up resources when the client is destroyed.""" + if self._refresh_token_timer: + try: + # Try to cancel the timer + self._refresh_token_timer.cancel() + except: + # Ignore errors if event loop is closed or selector is not registered + pass + finally: + # Always set to None to prevent further attempts + self._refresh_token_timer = None diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index fabae947..169ebd0e 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -1231,3 +1231,16 @@ def get_claims( # If verification succeeds, decode and return claims return ClaimsResponse(claims=payload, headers=header, signature=signature) + + def __del__(self) -> None: + """Clean up resources when the client is destroyed.""" + if self._refresh_token_timer: + try: + # Try to cancel the timer + self._refresh_token_timer.cancel() + except: + # Ignore errors if event loop is closed or selector is not registered + pass + finally: + # Always set to None to prevent further attempts + self._refresh_token_timer = None