From 36565bf86d12fe39753605384a410ac0d8ea0d4f Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Tue, 25 Feb 2025 15:17:07 -0300
Subject: [PATCH 01/15] feat: add get_claims method

---
 poetry.lock                           | 10 ++---
 pyproject.toml                        |  1 +
 supabase_auth/_async/gotrue_client.py | 49 +++++++++++++++++++++--
 supabase_auth/errors.py               |  8 ++++
 supabase_auth/helpers.py              | 56 +++++++++++++++++++++++----
 supabase_auth/types.py                | 30 ++++++++++++++
 6 files changed, 139 insertions(+), 15 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index aecdb529..d3c96613 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "annotated-types"
@@ -1003,13 +1003,13 @@ urllib3 = ">=1.26.0"
 
 [[package]]
 name = "pyjwt"
-version = "2.9.0"
+version = "2.10.1"
 description = "JSON Web Token implementation in Python"
 optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.9"
 files = [
-    {file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"},
-    {file = "pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c"},
+    {file = "PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb"},
+    {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"},
 ]
 
 [package.dependencies]
diff --git a/pyproject.toml b/pyproject.toml
index 28dff73c..d6873904 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ classifiers = [
 python = "^3.9"
 httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
 pydantic = ">=1.10,<3"
+pyjwt = "^2.10.1"
 
 [tool.poetry.dev-dependencies]
 pytest = "^8.3.5"
diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py
index 2a2a5564..4069991b 100644
--- a/supabase_auth/_async/gotrue_client.py
+++ b/supabase_auth/_async/gotrue_client.py
@@ -4,7 +4,7 @@
 from functools import partial
 from json import loads
 from time import time
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from urllib.parse import parse_qs, urlencode, urlparse
 from uuid import uuid4
 
@@ -20,11 +20,12 @@
     AuthApiError,
     AuthImplicitGrantRedirectError,
     AuthInvalidCredentialsError,
+    AuthInvalidJwtError,
     AuthRetryableError,
     AuthSessionMissingError,
 )
 from ..helpers import (
-    decode_jwt_payload,
+    decode_jwt,
     generate_pkce_challenge,
     generate_pkce_verifier,
     model_dump,
@@ -39,6 +40,8 @@
 from ..http_clients import AsyncClient
 from ..timer import Timer
 from ..types import (
+    JWK,
+    JWKS,
     AuthChangeEvent,
     AuthenticatorAssuranceLevels,
     AuthFlowType,
@@ -106,6 +109,7 @@ def __init__(
             verify=verify,
             proxy=proxy,
         )
+        self._jwks: JWKS = {}
         self._storage_key = storage_key or STORAGE_KEY
         self._auto_refresh_token = auto_refresh_token
         self._persist_session = persist_session
@@ -1128,7 +1132,8 @@ def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
         """
         Decodes a JWT (without performing any validation).
         """
-        return decode_jwt_payload(jwt)
+        decoded = decode_jwt(jwt)
+        return decoded["payload"]
 
     async def exchange_code_for_session(self, params: CodeExchangeParams):
         code_verifier = params.get("code_verifier") or await self._storage.get_item(
@@ -1150,3 +1155,41 @@ async def exchange_code_for_session(self, params: CodeExchangeParams):
             await self._save_session(response.session)
             self._notify_all_subscribers("SIGNED_IN", response.session)
         return response
+
+    async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK:
+        # try fetching from the suplied keys.
+        jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None)
+
+        if jwk:
+            return jwk
+
+        # try fetching from the cache.
+        jwk = next((jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None)
+        if jwk:
+            return jwk
+
+        # jwk isn't cached in memory so we need to fetch it from the well-known endpoint
+        response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks)
+        if response.jwks:
+            self._jwks = response.jwks
+
+            # find the signing key
+            jwk = next(
+                (jwk for jwk in response.jwks["keys"] if jwk["kid"] == kid), None
+            )
+            if not jwk:
+                raise AuthInvalidJwtError("No matching signing key found in JWKS")
+
+            return jwk
+
+        raise AuthInvalidJwtError("JWT has no valid kid")
+
+    async def get_claims(self):
+        pass
+
+
+def parse_jwks(response: Any) -> JWKS:
+    if "keys" not in response or len(response.keys) == 0:
+        raise AuthInvalidJwtError("JWKS is empty")
+
+    return JWKS(keys=response.keys)
diff --git a/supabase_auth/errors.py b/supabase_auth/errors.py
index fa693894..1cf8de3c 100644
--- a/supabase_auth/errors.py
+++ b/supabase_auth/errors.py
@@ -225,3 +225,11 @@ def to_dict(self) -> AuthApiErrorDict:
             "status": self.status,
             "reasons": self.reasons,
         }
+
+class AuthInvalidJwtError(CustomAuthError):
+    def __init__(self, message: str) -> None:
+        CustomAuthError.__init__(
+            self,
+            message,
+            "AuthInvalidJwtError",
+            "invalid_jwt",
\ No newline at end of file
diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py
index ae9dc7c5..8be212e2 100644
--- a/supabase_auth/helpers.py
+++ b/supabase_auth/helpers.py
@@ -8,16 +8,19 @@
 from base64 import urlsafe_b64decode
 from datetime import datetime
 from json import loads
-from typing import Any, Dict, Optional, Type, TypeVar, cast
+from typing import Any, Callable, Dict, Literal, Optional, Type, TypeVar, cast
 from urllib.parse import urlparse
 
 from httpx import HTTPStatusError, Response
+import jwt
+import jwt.algorithms
 from pydantic import BaseModel
 
 from .constants import API_VERSION_HEADER_NAME, API_VERSIONS
 from .errors import (
     AuthApiError,
     AuthError,
+    AuthInvalidJwtError,
     AuthRetryableError,
     AuthUnknownError,
     AuthWeakPasswordError,
@@ -192,15 +195,38 @@ def handle_exception(exception: Exception) -> AuthError:
         return AuthUnknownError(get_error_message(error), e)
 
 
-def decode_jwt_payload(token: str) -> Any:
-    parts = token.split(".")
-    if len(parts) != 3:
-        raise ValueError("JWT is not valid: not a JWT structure")
-    base64url = parts[1]
+def str_from_base64url(base64url: str) -> str:
     # Addding padding otherwise the following error happens:
     # binascii.Error: Incorrect padding
     base64url_with_padding = base64url + "=" * (-len(base64url) % 4)
-    return loads(urlsafe_b64decode(base64url_with_padding).decode("utf-8"))
+    return urlsafe_b64decode(base64url_with_padding).decode("utf-8")
+
+
+def base64url_to_bytes(base64url: str) -> bytes:
+    # Addding padding otherwise the following error happens:
+    # binascii.Error: Incorrect padding
+    base64url_with_padding = base64url + "=" * (-len(base64url) % 4)
+    return urlsafe_b64decode(base64url_with_padding)
+
+
+def decode_jwt(token: str) -> Dict[str, Any]:
+    parts = token.split(".")
+    if len(parts) != 3:
+        raise AuthInvalidJwtError("Invalid JWT structure")
+
+    # regex check for base64url
+    if not re.match(BASE64URL_REGEX, parts[1]):
+        raise AuthInvalidJwtError("JWT not in base64url format")
+
+    return {
+        "header": loads(str_from_base64url(parts[0])),
+        "payload": loads(str_from_base64url(parts[1])),
+        "signature": base64url_to_bytes(parts[2]),
+        "raw": {
+            "header": parts[0],
+            "payload": parts[1],
+        },
+    }
 
 
 def generate_pkce_verifier(length=64):
@@ -267,3 +293,19 @@ def is_valid_jwt(value: str) -> bool:
             return False
 
     return True
+
+
+def validate_exp(exp: int) -> None:
+    if not exp:
+        raise AuthInvalidJwtError("JWT has no expiration time")
+
+    time_now = datetime.now().timestamp()
+    if exp <= time_now:
+        raise AuthInvalidJwtError("JWT has expired")
+
+
+def get_algorithm(alg: Literal["RS256", "ES256"]) -> jwt.algorithms.Algorithm:
+    if alg == "RS256":
+        return jwt.algorithms.RSAAlgorithm
+    elif alg == "ES256":
+        return jwt.algorithms.ECAlgorithm
diff --git a/supabase_auth/types.py b/supabase_auth/types.py
index 86bda3e2..dd499aa0 100644
--- a/supabase_auth/types.py
+++ b/supabase_auth/types.py
@@ -789,6 +789,36 @@ class SignOutOptions(TypedDict):
     scope: NotRequired[SignOutScope]
 
 
+class JWTHeader(TypedDict):
+    alg: Literal["RS256", "ES256", "HS256"]
+    typ: str
+    kid: str
+
+
+class RequiredClaims(TypedDict):
+    iss: str
+    sub: str
+    auth: Union[str, List[str]]
+    exp: int
+    iat: int
+    role: str
+    aal: AuthenticatorAssuranceLevels
+    session_id: str
+
+
+class JWTPayload(RequiredClaims, TypedDict, total=False):
+    pass
+
+
+class JWK(TypedDict, total=False):
+    kty: Literal["RSA", "EC", "oct"]
+    key_ops: List[str]
+    alg: Optional[str]
+    kid: Optional[str]
+
+class JWKS(TypedDict):
+    keys: List[JWK]
+
 for model in [
     AMREntry,
     AuthResponse,

From 023d610f78b097bb8dc2937cc9d5c4ed5c18be51 Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Tue, 25 Feb 2025 15:31:01 -0300
Subject: [PATCH 02/15] sync infra

---
 infra/db/00-schema.sql   |  1 +
 infra/docker-compose.yml | 53 +++++++++++++++++++++++++++++++---------
 2 files changed, 43 insertions(+), 11 deletions(-)

diff --git a/infra/db/00-schema.sql b/infra/db/00-schema.sql
index 229af46f..ba844794 100644
--- a/infra/db/00-schema.sql
+++ b/infra/db/00-schema.sql
@@ -83,3 +83,4 @@ GRANT ALL PRIVILEGES ON SCHEMA auth TO postgres;
 GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA auth TO postgres;
 GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA auth TO postgres;
 ALTER USER postgres SET search_path = "auth";
+
diff --git a/infra/docker-compose.yml b/infra/docker-compose.yml
index a4cf04c7..29d4d1ed 100644
--- a/infra/docker-compose.yml
+++ b/infra/docker-compose.yml
@@ -2,7 +2,7 @@
 version: '3'
 services:
   gotrue: # Signup enabled, autoconfirm off
-    image: supabase/auth:v2.132.3
+    image: supabase/auth:v2.151.0
     ports:
       - '9999:9999'
     environment:
@@ -32,21 +32,22 @@ services:
       GOTRUE_SMTP_ADMIN_EMAIL: admin@email.com
       GOTRUE_MAILER_SUBJECTS_CONFIRMATION: 'Please confirm'
       GOTRUE_EXTERNAL_PHONE_ENABLED: 'true'
-      GOTRUE_SMS_PROVIDER: "twilio"
-      GOTRUE_SMS_TWILIO_ACCOUNT_SID: "${GOTRUE_SMS_TWILIO_ACCOUNT_SID}"
-      GOTRUE_SMS_TWILIO_AUTH_TOKEN: "${GOTRUE_SMS_TWILIO_AUTH_TOKEN}"
-      GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID: "${GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID}"
+      GOTRUE_SMS_PROVIDER: 'twilio'
+      GOTRUE_SMS_TWILIO_ACCOUNT_SID: '${GOTRUE_SMS_TWILIO_ACCOUNT_SID}'
+      GOTRUE_SMS_TWILIO_AUTH_TOKEN: '${GOTRUE_SMS_TWILIO_AUTH_TOKEN}'
+      GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID: '${GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID}'
       GOTRUE_SMS_AUTOCONFIRM: 'false'
-      GOTRUE_COOKIE_KEY: "sb"
+      GOTRUE_COOKIE_KEY: 'sb'
     depends_on:
       - db
     restart: on-failure
   autoconfirm: # Signup enabled, autoconfirm on
-    image: supabase/auth:v2.132.3
+    image: supabase/auth:v2.151.0
     ports:
       - '9998:9998'
     environment:
       GOTRUE_JWT_SECRET: '37c304f8-51aa-419a-a1af-06154e63707a'
+      GOTRUE_JWT_KEYS: '[{"kty":"oct","k":"Z7-AyPyChGNcQsX16cPBV-pPBo4q-zckDxkq1VZjATo","kid":"12580317-221c-49b6-894a-f4473b8afe39","key_ops":["sign", "verify"],"alg":"HS256"},{"kty":"RSA","n":"y3KQnIXK6wkPQ5m0XWp7z54BNZzXJk4IxXy81zFophdBBqz6u5OCMqWkC6i3WB7rlax4xjmxxyGyYRODooqCQTGahmpXryAAKc3g-gDIAq2MqVwlpmvXDavCVRK4hK7DZ6wK4MHrliSNHCuCkwIH3ofxTxgUwpSkOT58iU1ZOua5E1Y6R_Ozt3gLHha0Xa7a4V23pkP7n0xBvJPzIqiS3MZ4CQ_pz-buXYRgCPQkUJvXFFcuxmyqoYzorwQ1YVBOmH2XMx26RrCIxgj7geo9eVQ9u5qCPpQCGV5biqYMC4_m1kurOGf62URGRzXtmVzrW1PZJAeGoqMz5Fcfr8hiwQ","e":"AQAB","d":"C4XxquvpEmbw9mM-VAwz9w58Aw1fIkxJMuZdy9KAmue2RyqFCRrRxQycvgxQVi1qKpAaRx_9ccn20IjKa-psdkTY-8QKM2EcoUGH_KEOsxghX3ZYq5RwGdYgq7DjwqAjcTvNYe2Z6mcnlvDf9HOo_nG0uUYj5uGEa7meVCiNZUiSVdNGs-vOTUD8yB5pbZ4ute8ebuUzCWGQ3YwSNoWLa-dbECSO7jeobCapdB52MjEwE3_Ii8BWoySeDP-DEFX_5RTM2Zeh81zXAgmOxpZYTkjMsrznyxxBbXn7CdT8WMEXrreGZwIt3Mu6XpsLF5mwmTQ_ZyoM6tJpn5LeAhnCAQ","p":"6xy1skrnlrGUWtZFSHixn_eRA_O3GXKNBE4wziWodGZaFYsmFijZHbuQT0WFqc0epvLHNdNPvubFrVfV-U7ZIarfSSq6qBwBzDrDQS060MvjJIjrI16pKlx2X727FR1ZuwxT27dNg-wRTgKcZqXEalkvFOTEYBlCtw2-vzI0aRs","q":"3YWwOAs4GRZ9eq_fqNujACWJFyUO9QgEDPDOMg0EZhY7WkAlehTxxVXg65spWnfx_0GSc72I5N5qdbY-yDh2Dl7zIxvwnqZaKMJn4PEFkeAfyg62XlJlkHIwOVSj6vLNUDdDmG7bO2k6MyQ59jeuAemIljf9WhALNy8c9R0K3VM","dp":"KJ4LHcQnAjeng5Hk4kJHnXUtjls6VKEfj5DaiaKj2YgdI_-oEsf3ylUu9yLxloYjN4BVvgzFiBtiJzI3exyOEmzsqj1Bhe1guiGkvcvMj2nJ0fP9e1zNKM5UfPHQMjOh3tigXCLst0-_JZT55BnbNuw1YAytiFSU2_755xoLR-U","dq":"dCP7V-bJ6p1X_FLpOGau9wy262OKi_0_4mj-Mk-Q1tUhGRg4jeEdQRDdc6lN7Rilz-ZZGkVs2FGkD0MVd3PisXYmk2m6pfMhoe0K-WxkNy8Ce7Vq99jLVwgHMIenyS6zZjMTRYAZgPSShu2fVe-rU2VVLyz7r5RpzOzuibRIVfE","qi":"i7ND2teiVLkbaAs6rHfo5DiD1nlsORNYnn8Y_FjF6utb5OUljZ6-5WyEDJN9oIUX8o_Il9E6js-z7nhvPfFZHQN7ZWuYI0rO5qmsCDS9jWJ4GR61SgzZuLT7Jpp_KtwjW70x5wZ1Y-GugOP1Wct1YZWHn5YyLhvO6X_vttSmcS0","kid":"638c54b8-28c2-4b12-9598-ba12ef610a29","key_ops":["verify"],"alg":"RS256"}]'
       GOTRUE_JWT_EXP: 3600
       GOTRUE_DB_DRIVER: postgres
       DB_NAMESPACE: auth
@@ -66,12 +67,42 @@ services:
       GOTRUE_SMTP_USER: GOTRUE_SMTP_USER
       GOTRUE_SMTP_PASS: GOTRUE_SMTP_PASS
       GOTRUE_SMTP_ADMIN_EMAIL: admin@email.com
-      GOTRUE_COOKIE_KEY: "sb"
+      GOTRUE_COOKIE_KEY: 'sb'
+    depends_on:
+      - db
+    restart: on-failure
+  autoconfirm_with_asymmetric_keys: # Signup enabled, autoconfirm on
+    image: supabase/auth:v2.169.0
+    ports:
+      - '9996:9996'
+    environment:
+      GOTRUE_JWT_SECRET: 'Z7-AyPyChGNcQsX16cPBV-pPBo4q-zckDxkq1VZjATo'
+      GOTRUE_JWT_KEYS: '[{"kty":"oct","k":"Z7-AyPyChGNcQsX16cPBV-pPBo4q-zckDxkq1VZjATo","kid":"12580317-221c-49b6-894a-f4473b8afe39","key_ops":["verify"],"alg":"HS256"},{"kty":"RSA","n":"y3KQnIXK6wkPQ5m0XWp7z54BNZzXJk4IxXy81zFophdBBqz6u5OCMqWkC6i3WB7rlax4xjmxxyGyYRODooqCQTGahmpXryAAKc3g-gDIAq2MqVwlpmvXDavCVRK4hK7DZ6wK4MHrliSNHCuCkwIH3ofxTxgUwpSkOT58iU1ZOua5E1Y6R_Ozt3gLHha0Xa7a4V23pkP7n0xBvJPzIqiS3MZ4CQ_pz-buXYRgCPQkUJvXFFcuxmyqoYzorwQ1YVBOmH2XMx26RrCIxgj7geo9eVQ9u5qCPpQCGV5biqYMC4_m1kurOGf62URGRzXtmVzrW1PZJAeGoqMz5Fcfr8hiwQ","e":"AQAB","d":"C4XxquvpEmbw9mM-VAwz9w58Aw1fIkxJMuZdy9KAmue2RyqFCRrRxQycvgxQVi1qKpAaRx_9ccn20IjKa-psdkTY-8QKM2EcoUGH_KEOsxghX3ZYq5RwGdYgq7DjwqAjcTvNYe2Z6mcnlvDf9HOo_nG0uUYj5uGEa7meVCiNZUiSVdNGs-vOTUD8yB5pbZ4ute8ebuUzCWGQ3YwSNoWLa-dbECSO7jeobCapdB52MjEwE3_Ii8BWoySeDP-DEFX_5RTM2Zeh81zXAgmOxpZYTkjMsrznyxxBbXn7CdT8WMEXrreGZwIt3Mu6XpsLF5mwmTQ_ZyoM6tJpn5LeAhnCAQ","p":"6xy1skrnlrGUWtZFSHixn_eRA_O3GXKNBE4wziWodGZaFYsmFijZHbuQT0WFqc0epvLHNdNPvubFrVfV-U7ZIarfSSq6qBwBzDrDQS060MvjJIjrI16pKlx2X727FR1ZuwxT27dNg-wRTgKcZqXEalkvFOTEYBlCtw2-vzI0aRs","q":"3YWwOAs4GRZ9eq_fqNujACWJFyUO9QgEDPDOMg0EZhY7WkAlehTxxVXg65spWnfx_0GSc72I5N5qdbY-yDh2Dl7zIxvwnqZaKMJn4PEFkeAfyg62XlJlkHIwOVSj6vLNUDdDmG7bO2k6MyQ59jeuAemIljf9WhALNy8c9R0K3VM","dp":"KJ4LHcQnAjeng5Hk4kJHnXUtjls6VKEfj5DaiaKj2YgdI_-oEsf3ylUu9yLxloYjN4BVvgzFiBtiJzI3exyOEmzsqj1Bhe1guiGkvcvMj2nJ0fP9e1zNKM5UfPHQMjOh3tigXCLst0-_JZT55BnbNuw1YAytiFSU2_755xoLR-U","dq":"dCP7V-bJ6p1X_FLpOGau9wy262OKi_0_4mj-Mk-Q1tUhGRg4jeEdQRDdc6lN7Rilz-ZZGkVs2FGkD0MVd3PisXYmk2m6pfMhoe0K-WxkNy8Ce7Vq99jLVwgHMIenyS6zZjMTRYAZgPSShu2fVe-rU2VVLyz7r5RpzOzuibRIVfE","qi":"i7ND2teiVLkbaAs6rHfo5DiD1nlsORNYnn8Y_FjF6utb5OUljZ6-5WyEDJN9oIUX8o_Il9E6js-z7nhvPfFZHQN7ZWuYI0rO5qmsCDS9jWJ4GR61SgzZuLT7Jpp_KtwjW70x5wZ1Y-GugOP1Wct1YZWHn5YyLhvO6X_vttSmcS0","kid":"638c54b8-28c2-4b12-9598-ba12ef610a29","key_ops":["sign","verify"],"alg":"RS256"}]'
+      GOTRUE_JWT_EXP: 3600
+      GOTRUE_DB_DRIVER: postgres
+      DB_NAMESPACE: auth
+      GOTRUE_API_HOST: 0.0.0.0
+      PORT: 9996
+      GOTRUE_DISABLE_SIGNUP: 'false'
+      API_EXTERNAL_URL: http://localhost:9996
+      GOTRUE_SITE_URL: http://localhost:9996
+      GOTRUE_MAILER_AUTOCONFIRM: 'true'
+      GOTRUE_SMS_AUTOCONFIRM: 'true'
+      GOTRUE_LOG_LEVEL: DEBUG
+      GOTRUE_OPERATOR_TOKEN: super-secret-operator-token
+      DATABASE_URL: 'postgres://postgres:postgres@db:5432/postgres?sslmode=disable'
+      GOTRUE_EXTERNAL_PHONE_ENABLED: 'true'
+      GOTRUE_SMTP_HOST: mail
+      GOTRUE_SMTP_PORT: 2500
+      GOTRUE_SMTP_USER: GOTRUE_SMTP_USER
+      GOTRUE_SMTP_PASS: GOTRUE_SMTP_PASS
+      GOTRUE_SMTP_ADMIN_EMAIL: admin@email.com
+      GOTRUE_COOKIE_KEY: 'sb'
     depends_on:
       - db
     restart: on-failure
   disabled: # Signup disabled
-    image: supabase/auth:v2.132.3
+    image: supabase/auth:v2.151.0
     ports:
       - '9997:9997'
     environment:
@@ -95,7 +126,7 @@ services:
       GOTRUE_SMTP_USER: GOTRUE_SMTP_USER
       GOTRUE_SMTP_PASS: GOTRUE_SMTP_PASS
       GOTRUE_SMTP_ADMIN_EMAIL: admin@email.com
-      GOTRUE_COOKIE_KEY: "sb"
+      GOTRUE_COOKIE_KEY: 'sb'
     depends_on:
       - db
     restart: on-failure
@@ -106,7 +137,7 @@ services:
       - '9000:9000' # web interface
       - '1100:1100' # POP3
   db:
-    image: supabase/postgres:15.1.1.66
+    image: supabase/postgres:15.1.1.46
     ports:
       - '5432:5432'
     command: postgres -c config_file=/etc/postgresql/postgresql.conf

From faaf88c654f4358c870c2bd268d51771184c9968 Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Tue, 25 Feb 2025 15:31:11 -0300
Subject: [PATCH 03/15] fix tests

---
 supabase_auth/errors.py  |  3 ++-
 supabase_auth/helpers.py |  2 +-
 tests/test_helpers.py    | 12 ++++++------
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/supabase_auth/errors.py b/supabase_auth/errors.py
index 1cf8de3c..6764aa82 100644
--- a/supabase_auth/errors.py
+++ b/supabase_auth/errors.py
@@ -232,4 +232,5 @@ def __init__(self, message: str) -> None:
             self,
             message,
             "AuthInvalidJwtError",
-            "invalid_jwt",
\ No newline at end of file
+            "invalid_jwt",
+        )
\ No newline at end of file
diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py
index 8be212e2..838d16d5 100644
--- a/supabase_auth/helpers.py
+++ b/supabase_auth/helpers.py
@@ -8,7 +8,7 @@
 from base64 import urlsafe_b64decode
 from datetime import datetime
 from json import loads
-from typing import Any, Callable, Dict, Literal, Optional, Type, TypeVar, cast
+from typing import Any, Dict, Literal, Optional, Type, TypeVar, cast
 from urllib.parse import urlparse
 
 from httpx import HTTPStatusError, Response
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 57dcb66a..4132e9e6 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -6,9 +6,9 @@
 from httpx import Headers, Response
 
 from supabase_auth.constants import API_VERSION_HEADER_NAME
-from supabase_auth.errors import AuthApiError, AuthWeakPasswordError
+from supabase_auth.errors import AuthApiError, AuthInvalidJwtError, AuthWeakPasswordError
 from supabase_auth.helpers import (
-    decode_jwt_payload,
+    decode_jwt,
     generate_pkce_challenge,
     generate_pkce_verifier,
     get_error_code,
@@ -111,13 +111,13 @@ def test_get_error_code():
     assert get_error_code({"error_code": "500"}) == "500"
 
 
-def test_decode_jwt_payload():
-    assert decode_jwt_payload(mock_access_token())
+def test_decode_jwt():
+    assert decode_jwt(mock_access_token())
 
     with pytest.raises(
-        ValueError, match=r"JWT is not valid: not a JWT structure"
+        AuthInvalidJwtError, match=r"Invalid JWT structure"
     ) as exc:
-        decode_jwt_payload("non-valid-jwt")
+        decode_jwt("non-valid-jwt")
     assert exc.value is not None
 
 

From 8ac31322ed0111fd81ae9d5ae9d2a521ca70be53 Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Tue, 25 Feb 2025 15:51:35 -0300
Subject: [PATCH 04/15] fix decode_jwt tests

---
 Makefile                             |  2 +-
 poetry.lock                          | 60 +++++++++++++++--------
 pyproject.toml                       |  6 +--
 supabase_auth/_sync/gotrue_client.py | 61 ++++++++++++++++++++---
 supabase_auth/errors.py              |  1 +
 supabase_auth/helpers.py             | 32 +++++++-----
 tests/_sync/test_gotrue_admin_api.py | 73 +++++++++++++++-------------
 7 files changed, 161 insertions(+), 74 deletions(-)

diff --git a/Makefile b/Makefile
index 69857fde..5289f26c 100644
--- a/Makefile
+++ b/Makefile
@@ -38,7 +38,7 @@ build_run_tests: build_sync run_tests
 	echo "Done"
 
 sleep:
-	sleep 20
+	sleep 3
 
 rename_project: rename_package_dir rename_package
 
diff --git a/poetry.lock b/poetry.lock
index d3c96613..d6f729f7 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1221,18 +1221,19 @@ httpx = ">=0.25.0"
 
 [[package]]
 name = "setuptools"
-version = "58.5.3"
+version = "72.2.0"
 description = "Easily download, build, install, upgrade, and uninstall Python packages"
 optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
 files = [
-    {file = "setuptools-58.5.3-py3-none-any.whl", hash = "sha256:a481fbc56b33f5d8f6b33dce41482e64c68b668be44ff42922903b03872590bf"},
-    {file = "setuptools-58.5.3.tar.gz", hash = "sha256:dae6b934a965c8a59d6d230d3867ec408bb95e73bd538ff77e71fedf1eaca729"},
+    {file = "setuptools-72.2.0-py3-none-any.whl", hash = "sha256:f11dd94b7bae3a156a95ec151f24e4637fb4fa19c878e4d191bfb8b2d82728c4"},
+    {file = "setuptools-72.2.0.tar.gz", hash = "sha256:80aacbf633704e9c8bfa1d99fa5dd4dc59573efcf9e4042c13d3bcef91ac2ef9"},
 ]
 
 [package.extras]
-docs = ["furo", "jaraco.packaging (>=8.2)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx", "sphinx-inline-tabs", "sphinxcontrib-towncrier"]
-testing = ["flake8-2020", "jaraco.envs", "jaraco.path (>=3.2.0)", "mock", "paver", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.0.1)", "pytest-flake8", "pytest-mypy", "pytest-virtualenv (>=1.2.7)", "pytest-xdist", "sphinx", "virtualenv (>=13.0.0)", "wheel"]
+core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
+test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
 
 [[package]]
 name = "sniffio"
@@ -1245,6 +1246,17 @@ files = [
     {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
 ]
 
+[[package]]
+name = "tokenize-rt"
+version = "6.1.0"
+description = "A wrapper around the stdlib `tokenize` which roundtrips."
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "tokenize_rt-6.1.0-py2.py3-none-any.whl", hash = "sha256:d706141cdec4aa5f358945abe36b911b8cbdc844545da99e811250c0cee9b6fc"},
+    {file = "tokenize_rt-6.1.0.tar.gz", hash = "sha256:e8ee836616c0877ab7c7b54776d2fefcc3bde714449a206762425ae114b53c86"},
+]
+
 [[package]]
 name = "tomli"
 version = "2.0.2"
@@ -1300,30 +1312,38 @@ files = [
 
 [[package]]
 name = "unasync"
-version = "0.5.0"
+version = "0.6.0"
 description = "The async transformation code."
 optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
+python-versions = ">=3.8"
 files = [
-    {file = "unasync-0.5.0-py3-none-any.whl", hash = "sha256:8d4536dae85e87b8751dfcc776f7656fd0baf54bb022a7889440dc1b9dc3becb"},
-    {file = "unasync-0.5.0.tar.gz", hash = "sha256:b675d87cf56da68bd065d3b7a67ac71df85591978d84c53083c20d79a7e5096d"},
+    {file = "unasync-0.6.0-py3-none-any.whl", hash = "sha256:9cf7aaaea9737e417d8949bf9be55dc25fdb4ef1f4edc21b58f76ff0d2b9d73f"},
+    {file = "unasync-0.6.0.tar.gz", hash = "sha256:a9d01ace3e1068b20550ab15b7f9723b15b8bcde728bc1770bcb578374c7ee58"},
 ]
 
+[package.dependencies]
+setuptools = "*"
+tokenize-rt = "*"
+
 [[package]]
 name = "unasync-cli"
-version = "0.0.9"
-description = "Command line interface for unasync"
+version = "0.0.1"
+description = "Command line interface for unasync. Fork of https://github.com/leynier/unasync-cli/"
 optional = false
-python-versions = ">=3.6.14,<4.0.0"
-files = [
-    {file = "unasync-cli-0.0.9.tar.gz", hash = "sha256:ca9d8c57ebb68911f8f8f68f243c7f6d0bb246ee3fd14743bc51c8317e276554"},
-    {file = "unasync_cli-0.0.9-py3-none-any.whl", hash = "sha256:f96c42fb2862efa555ce6d6415a5983ceb162aa0e45be701656d20a955c7c540"},
-]
+python-versions = "^3.8.18"
+files = []
+develop = false
 
 [package.dependencies]
-setuptools = ">=58.2.0,<59.0.0"
-typer = ">=0.4.0,<0.5.0"
-unasync = ">=0.5.0,<0.6.0"
+setuptools = "^72.1.0"
+typer = "^0.4.0"
+unasync = "^0.6.0"
+
+[package.source]
+type = "git"
+url = "https://github.com/supabase-community/unasync-cli.git"
+reference = "main"
+resolved_reference = "6a082ee36d5e8941622b70f6cbcaf8e7a5be339d"
 
 [[package]]
 name = "urllib3"
diff --git a/pyproject.toml b/pyproject.toml
index d6873904..09c7c90d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,12 +11,12 @@ license = "MIT"
 classifiers = [
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: MIT License",
-    "Operating System :: OS Independent"
+    "Operating System :: OS Independent",
 ]
 
 [tool.poetry.dependencies]
 python = "^3.9"
-httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
+httpx = { version = ">=0.26,<0.29", extras = ["http2"] }
 pydantic = ">=1.10,<3"
 pyjwt = "^2.10.1"
 
@@ -30,7 +30,7 @@ pytest-cov = "^6.0.0"
 pytest-depends = "^1.0.1"
 pytest-asyncio = "^0.25.3"
 Faker = "^36.1.1"
-unasync-cli = "^0.0.9"
+unasync-cli = { git = "https://github.com/supabase-community/unasync-cli.git", branch = "main" }
 
 [tool.poetry.group.dev.dependencies]
 pygithub = ">=1.57,<3.0"
diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py
index e4e821ef..db6040b0 100644
--- a/supabase_auth/_sync/gotrue_client.py
+++ b/supabase_auth/_sync/gotrue_client.py
@@ -4,7 +4,7 @@
 from functools import partial
 from json import loads
 from time import time
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from urllib.parse import parse_qs, urlencode, urlparse
 from uuid import uuid4
 
@@ -20,11 +20,12 @@
     AuthApiError,
     AuthImplicitGrantRedirectError,
     AuthInvalidCredentialsError,
+    AuthInvalidJwtError,
     AuthRetryableError,
     AuthSessionMissingError,
 )
 from ..helpers import (
-    decode_jwt_payload,
+    decode_jwt,
     generate_pkce_challenge,
     generate_pkce_verifier,
     model_dump,
@@ -39,6 +40,8 @@
 from ..http_clients import SyncClient
 from ..timer import Timer
 from ..types import (
+    JWK,
+    JWKS,
     AuthChangeEvent,
     AuthenticatorAssuranceLevels,
     AuthFlowType,
@@ -106,6 +109,7 @@ def __init__(
             verify=verify,
             proxy=proxy,
         )
+        self._jwks: JWKS = {}
         self._storage_key = storage_key or STORAGE_KEY
         self._auto_refresh_token = auto_refresh_token
         self._persist_session = persist_session
@@ -415,7 +419,9 @@ def sign_in_with_oauth(
         )
         return OAuthResponse(provider=provider, url=url_with_qs)
 
-    def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse:
+    def link_identity(
+        self, credentials: SignInWithOAuthCredentials
+    ) -> OAuthResponse:
         provider = credentials.get("provider")
         options = credentials.get("options", {})
         redirect_to = options.get("redirect_to")
@@ -698,7 +704,9 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
         self._notify_all_subscribers("TOKEN_REFRESHED", session)
         return AuthResponse(session=session, user=response.user)
 
-    def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse:
+    def refresh_session(
+        self, refresh_token: Optional[str] = None
+    ) -> AuthResponse:
         """
         Returns a new session, regardless of expiry status.
 
@@ -1107,7 +1115,9 @@ def _get_url_for_provider(
         if self._flow_type == "pkce":
             code_verifier = generate_pkce_verifier()
             code_challenge = generate_pkce_challenge(code_verifier)
-            self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier)
+            self._storage.set_item(
+                f"{self._storage_key}-code-verifier", code_verifier
+            )
             code_challenge_method = (
                 "plain" if code_verifier == code_challenge else "s256"
             )
@@ -1122,7 +1132,8 @@ def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
         """
         Decodes a JWT (without performing any validation).
         """
-        return decode_jwt_payload(jwt)
+        decoded = decode_jwt(jwt)
+        return decoded["payload"]
 
     def exchange_code_for_session(self, params: CodeExchangeParams):
         code_verifier = params.get("code_verifier") or self._storage.get_item(
@@ -1144,3 +1155,41 @@ def exchange_code_for_session(self, params: CodeExchangeParams):
             self._save_session(response.session)
             self._notify_all_subscribers("SIGNED_IN", response.session)
         return response
+
+    def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK:
+        # try fetching from the suplied keys.
+        jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None)
+
+        if jwk:
+            return jwk
+
+        # try fetching from the cache.
+        jwk = next((jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None)
+        if jwk:
+            return jwk
+
+        # jwk isn't cached in memory so we need to fetch it from the well-known endpoint
+        response = self._request("GET", ".well-known/jwks.json", xform=parse_jwks)
+        if response.jwks:
+            self._jwks = response.jwks
+
+            # find the signing key
+            jwk = next(
+                (jwk for jwk in response.jwks["keys"] if jwk["kid"] == kid), None
+            )
+            if not jwk:
+                raise AuthInvalidJwtError("No matching signing key found in JWKS")
+
+            return jwk
+
+        raise AuthInvalidJwtError("JWT has no valid kid")
+
+    def get_claims(self):
+        pass
+
+
+def parse_jwks(response: Any) -> JWKS:
+    if "keys" not in response or len(response.keys) == 0:
+        raise AuthInvalidJwtError("JWKS is empty")
+
+    return JWKS(keys=response.keys)
diff --git a/supabase_auth/errors.py b/supabase_auth/errors.py
index 6764aa82..a26f939b 100644
--- a/supabase_auth/errors.py
+++ b/supabase_auth/errors.py
@@ -232,5 +232,6 @@ def __init__(self, message: str) -> None:
             self,
             message,
             "AuthInvalidJwtError",
+            400,
             "invalid_jwt",
         )
\ No newline at end of file
diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py
index 838d16d5..1a7c48fb 100644
--- a/supabase_auth/helpers.py
+++ b/supabase_auth/helpers.py
@@ -8,7 +8,7 @@
 from base64 import urlsafe_b64decode
 from datetime import datetime
 from json import loads
-from typing import Any, Dict, Literal, Optional, Type, TypeVar, cast
+from typing import Any, Dict, Literal, Optional, Type, TypeVar, TypedDict, cast
 from urllib.parse import urlparse
 
 from httpx import HTTPStatusError, Response
@@ -30,6 +30,8 @@
     AuthResponse,
     GenerateLinkProperties,
     GenerateLinkResponse,
+    JWTHeader,
+    JWTPayload,
     LinkIdentityResponse,
     Session,
     SSOResponse,
@@ -209,24 +211,32 @@ def base64url_to_bytes(base64url: str) -> bytes:
     return urlsafe_b64decode(base64url_with_padding)
 
 
-def decode_jwt(token: str) -> Dict[str, Any]:
+class DecodedJWT(TypedDict):
+    header: JWTHeader
+    payload: JWTPayload
+    signature: bytes
+    raw: Dict[str, str]
+
+
+def decode_jwt(token: str) -> DecodedJWT:
     parts = token.split(".")
     if len(parts) != 3:
         raise AuthInvalidJwtError("Invalid JWT structure")
 
     # regex check for base64url
-    if not re.match(BASE64URL_REGEX, parts[1]):
-        raise AuthInvalidJwtError("JWT not in base64url format")
-
-    return {
-        "header": loads(str_from_base64url(parts[0])),
-        "payload": loads(str_from_base64url(parts[1])),
-        "signature": base64url_to_bytes(parts[2]),
-        "raw": {
+    # for part in parts:
+    #     if not re.match(BASE64URL_REGEX, part):
+    #         raise AuthInvalidJwtError("JWT not in base64url format")
+
+    return DecodedJWT(
+        header=JWTHeader(**loads(str_from_base64url(parts[0]))),
+        payload=JWTPayload(**loads(str_from_base64url(parts[1]))),
+        signature=base64url_to_bytes(parts[2]),
+        raw={
             "header": parts[0],
             "payload": parts[1],
         },
-    }
+    )
 
 
 def generate_pkce_verifier(length=64):
diff --git a/tests/_sync/test_gotrue_admin_api.py b/tests/_sync/test_gotrue_admin_api.py
index c885d4bf..c9769f1f 100644
--- a/tests/_sync/test_gotrue_admin_api.py
+++ b/tests/_sync/test_gotrue_admin_api.py
@@ -17,7 +17,6 @@
 )
 from .utils import (
     create_new_user_with_email,
-    mock_access_token,
     mock_app_metadata,
     mock_user_credentials,
     mock_user_metadata,
@@ -31,19 +30,19 @@ def test_create_user_should_create_a_new_user():
     assert response.email == credentials.get("email")
 
 
-def test_create_user_with_app_metadata():
-    app_metadata = mock_app_metadata()
+def test_create_user_with_user_metadata():
+    user_metadata = mock_user_metadata()
     credentials = mock_user_credentials()
     response = service_role_api_client().create_user(
         {
             "email": credentials.get("email"),
             "password": credentials.get("password"),
-            "app_metadata": app_metadata,
+            "user_metadata": user_metadata,
         }
     )
     assert response.user.email == credentials.get("email")
-    assert "provider" in response.user.app_metadata
-    assert "providers" in response.user.app_metadata
+    assert response.user.user_metadata == user_metadata
+    assert "profile_image" in response.user.user_metadata
 
 
 def test_create_user_with_user_and_app_metadata():
@@ -157,7 +156,7 @@ def test_modify_confirm_email_using_update_user_by_id():
 
 def test_invalid_credential_sign_in_with_phone():
     try:
-        client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
+        response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
             {
                 "phone": "+123456789",
                 "password": "strong_pwd",
@@ -169,7 +168,7 @@ def test_invalid_credential_sign_in_with_phone():
 
 def test_invalid_credential_sign_in_with_email():
     try:
-        client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
+        response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
             {
                 "email": "unknown_user@unknowndomain.com",
                 "password": "strong_pwd",
@@ -364,39 +363,47 @@ def test_verify_otp_with_invalid_phone_number():
         assert e.message == "Invalid phone number format (E.164 required)"
 
 
-def test_sign_in_with_oauth():
-    assert client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth(
-        {
-            "provider": "google",
-        }
-    )
-
+def test_sign_in_with_id_token():
+    try:
+        client_api_auto_confirm_off_signups_enabled_client().sign_in_with_id_token(
+            {
+                "provider": "google",
+                "token": "123456",
+            }
+        )
+    except AuthApiError as e:
+        assert e.to_dict()
 
-def test_decode_jwt():
-    assert auth_client_with_session()._decode_jwt(mock_access_token())
 
+def test_sign_in_with_sso():
+    with pytest.raises(AuthApiError, match=r"SAML 2.0 is disabled") as exc:
+        client_api_auto_confirm_off_signups_enabled_client().sign_in_with_sso(
+            {
+                "domain": "google",
+            }
+        )
+    assert exc.value is not None
 
-def test_link_identity_missing_session():
 
-    with pytest.raises(AuthSessionMissingError) as exc:
-        client_api_auto_confirm_off_signups_enabled_client().link_identity(
+def test_sign_in_with_oauth():
+    assert (
+        client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth(
             {
                 "provider": "google",
             }
         )
-    assert exc.value is not None
+    )
 
 
-def test_sign_in_with_id_token():
-    try:
-        client_api_auto_confirm_off_signups_enabled_client().sign_in_with_id_token(
+def test_link_identity_missing_session():
+
+    with pytest.raises(AuthSessionMissingError) as exc:
+        client_api_auto_confirm_off_signups_enabled_client().link_identity(
             {
                 "provider": "google",
-                "token": "123456",
             }
         )
-    except AuthApiError as e:
-        assert e.to_dict()
+    assert exc.value is not None
 
 
 def test_get_item_from_memory_storage():
@@ -518,7 +525,7 @@ def test_get_user_identities():
             "password": credentials.get("password"),
         }
     )
-    assert client.get_user_identities().identities[0].identity_data[
+    assert (client.get_user_identities()).identities[0].identity_data[
         "email"
     ] == credentials.get("email")
 
@@ -542,19 +549,19 @@ def test_update_user():
     )
 
 
-def test_create_user_with_user_metadata():
-    user_metadata = mock_user_metadata()
+def test_create_user_with_app_metadata():
+    app_metadata = mock_app_metadata()
     credentials = mock_user_credentials()
     response = service_role_api_client().create_user(
         {
             "email": credentials.get("email"),
             "password": credentials.get("password"),
-            "user_metadata": user_metadata,
+            "app_metadata": app_metadata,
         }
     )
     assert response.user.email == credentials.get("email")
-    assert response.user.user_metadata == user_metadata
-    assert "profile_image" in response.user.user_metadata
+    assert "provider" in response.user.app_metadata
+    assert "providers" in response.user.app_metadata
 
 
 def test_weak_email_password_error():

From 66fa83a14d64a9650d7b71427c929460a14a6bda Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Tue, 25 Feb 2025 15:56:17 -0300
Subject: [PATCH 05/15] fix tests returning another error message

---
 infra/db/00-schema.sql                |  1 -
 supabase_auth/_sync/gotrue_client.py  | 12 +++------
 supabase_auth/errors.py               |  3 ++-
 supabase_auth/helpers.py              |  4 +--
 supabase_auth/types.py                |  2 ++
 tests/_async/test_gotrue_admin_api.py |  2 +-
 tests/_sync/test_gotrue_admin_api.py  | 36 ++++++++++++++-------------
 tests/test_helpers.py                 | 10 +++++---
 8 files changed, 35 insertions(+), 35 deletions(-)

diff --git a/infra/db/00-schema.sql b/infra/db/00-schema.sql
index ba844794..229af46f 100644
--- a/infra/db/00-schema.sql
+++ b/infra/db/00-schema.sql
@@ -83,4 +83,3 @@ GRANT ALL PRIVILEGES ON SCHEMA auth TO postgres;
 GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA auth TO postgres;
 GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA auth TO postgres;
 ALTER USER postgres SET search_path = "auth";
-
diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py
index db6040b0..6e1cbed4 100644
--- a/supabase_auth/_sync/gotrue_client.py
+++ b/supabase_auth/_sync/gotrue_client.py
@@ -419,9 +419,7 @@ def sign_in_with_oauth(
         )
         return OAuthResponse(provider=provider, url=url_with_qs)
 
-    def link_identity(
-        self, credentials: SignInWithOAuthCredentials
-    ) -> OAuthResponse:
+    def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse:
         provider = credentials.get("provider")
         options = credentials.get("options", {})
         redirect_to = options.get("redirect_to")
@@ -704,9 +702,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
         self._notify_all_subscribers("TOKEN_REFRESHED", session)
         return AuthResponse(session=session, user=response.user)
 
-    def refresh_session(
-        self, refresh_token: Optional[str] = None
-    ) -> AuthResponse:
+    def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse:
         """
         Returns a new session, regardless of expiry status.
 
@@ -1115,9 +1111,7 @@ def _get_url_for_provider(
         if self._flow_type == "pkce":
             code_verifier = generate_pkce_verifier()
             code_challenge = generate_pkce_challenge(code_verifier)
-            self._storage.set_item(
-                f"{self._storage_key}-code-verifier", code_verifier
-            )
+            self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier)
             code_challenge_method = (
                 "plain" if code_verifier == code_challenge else "s256"
             )
diff --git a/supabase_auth/errors.py b/supabase_auth/errors.py
index a26f939b..cc85b87e 100644
--- a/supabase_auth/errors.py
+++ b/supabase_auth/errors.py
@@ -226,6 +226,7 @@ def to_dict(self) -> AuthApiErrorDict:
             "reasons": self.reasons,
         }
 
+
 class AuthInvalidJwtError(CustomAuthError):
     def __init__(self, message: str) -> None:
         CustomAuthError.__init__(
@@ -234,4 +235,4 @@ def __init__(self, message: str) -> None:
             "AuthInvalidJwtError",
             400,
             "invalid_jwt",
-        )
\ No newline at end of file
+        )
diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py
index 1a7c48fb..57cecf31 100644
--- a/supabase_auth/helpers.py
+++ b/supabase_auth/helpers.py
@@ -8,12 +8,12 @@
 from base64 import urlsafe_b64decode
 from datetime import datetime
 from json import loads
-from typing import Any, Dict, Literal, Optional, Type, TypeVar, TypedDict, cast
+from typing import Any, Dict, Literal, Optional, Type, TypedDict, TypeVar, cast
 from urllib.parse import urlparse
 
-from httpx import HTTPStatusError, Response
 import jwt
 import jwt.algorithms
+from httpx import HTTPStatusError, Response
 from pydantic import BaseModel
 
 from .constants import API_VERSION_HEADER_NAME, API_VERSIONS
diff --git a/supabase_auth/types.py b/supabase_auth/types.py
index dd499aa0..5425173a 100644
--- a/supabase_auth/types.py
+++ b/supabase_auth/types.py
@@ -816,9 +816,11 @@ class JWK(TypedDict, total=False):
     alg: Optional[str]
     kid: Optional[str]
 
+
 class JWKS(TypedDict):
     keys: List[JWK]
 
+
 for model in [
     AMREntry,
     AuthResponse,
diff --git a/tests/_async/test_gotrue_admin_api.py b/tests/_async/test_gotrue_admin_api.py
index 701f46c8..dc96a145 100644
--- a/tests/_async/test_gotrue_admin_api.py
+++ b/tests/_async/test_gotrue_admin_api.py
@@ -344,7 +344,7 @@ async def test_verify_otp_with_non_existent_phone_number():
         )
         assert False
     except AuthError as e:
-        assert e.message == "User not found"
+        assert e.message == "Token has expired or is invalid"
 
 
 async def test_verify_otp_with_invalid_phone_number():
diff --git a/tests/_sync/test_gotrue_admin_api.py b/tests/_sync/test_gotrue_admin_api.py
index c9769f1f..dd5fa523 100644
--- a/tests/_sync/test_gotrue_admin_api.py
+++ b/tests/_sync/test_gotrue_admin_api.py
@@ -156,11 +156,13 @@ def test_modify_confirm_email_using_update_user_by_id():
 
 def test_invalid_credential_sign_in_with_phone():
     try:
-        response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
-            {
-                "phone": "+123456789",
-                "password": "strong_pwd",
-            }
+        response = (
+            client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
+                {
+                    "phone": "+123456789",
+                    "password": "strong_pwd",
+                }
+            )
         )
     except AuthApiError as e:
         assert e.to_dict()
@@ -168,11 +170,13 @@ def test_invalid_credential_sign_in_with_phone():
 
 def test_invalid_credential_sign_in_with_email():
     try:
-        response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
-            {
-                "email": "unknown_user@unknowndomain.com",
-                "password": "strong_pwd",
-            }
+        response = (
+            client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
+                {
+                    "email": "unknown_user@unknowndomain.com",
+                    "password": "strong_pwd",
+                }
+            )
         )
     except AuthApiError as e:
         assert e.to_dict()
@@ -344,7 +348,7 @@ def test_verify_otp_with_non_existent_phone_number():
         )
         assert False
     except AuthError as e:
-        assert e.message == "User not found"
+        assert e.message == "Token has expired or is invalid"
 
 
 def test_verify_otp_with_invalid_phone_number():
@@ -386,12 +390,10 @@ def test_sign_in_with_sso():
 
 
 def test_sign_in_with_oauth():
-    assert (
-        client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth(
-            {
-                "provider": "google",
-            }
-        )
+    assert client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth(
+        {
+            "provider": "google",
+        }
     )
 
 
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 4132e9e6..7fb27809 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -6,7 +6,11 @@
 from httpx import Headers, Response
 
 from supabase_auth.constants import API_VERSION_HEADER_NAME
-from supabase_auth.errors import AuthApiError, AuthInvalidJwtError, AuthWeakPasswordError
+from supabase_auth.errors import (
+    AuthApiError,
+    AuthInvalidJwtError,
+    AuthWeakPasswordError,
+)
 from supabase_auth.helpers import (
     decode_jwt,
     generate_pkce_challenge,
@@ -114,9 +118,7 @@ def test_get_error_code():
 def test_decode_jwt():
     assert decode_jwt(mock_access_token())
 
-    with pytest.raises(
-        AuthInvalidJwtError, match=r"Invalid JWT structure"
-    ) as exc:
+    with pytest.raises(AuthInvalidJwtError, match=r"Invalid JWT structure") as exc:
         decode_jwt("non-valid-jwt")
     assert exc.value is not None
 

From 4ae36c4110a74589e331b25baae07411df913dcd Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Mon, 10 Mar 2025 18:15:31 +0100
Subject: [PATCH 06/15] wip tests

---
 poetry.lock                           | 17 ++++++++++++
 pyproject.toml                        |  1 +
 supabase_auth/_async/gotrue_client.py | 39 ++++++++++++++++++++++-----
 supabase_auth/constants.py            |  1 +
 supabase_auth/helpers.py              |  3 +--
 supabase_auth/types.py                |  9 +++++--
 tests/_async/test_gotrue.py           | 24 +++++++++++++++++
 7 files changed, 84 insertions(+), 10 deletions(-)
 create mode 100644 tests/_async/test_gotrue.py

diff --git a/poetry.lock b/poetry.lock
index d6f729f7..1020f882 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1122,6 +1122,23 @@ future-fstrings = "*"
 networkx = "*"
 pytest = ">=3"
 
+[[package]]
+name = "pytest-mock"
+version = "3.14.0"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
+    {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
+]
+
+[package.dependencies]
+pytest = ">=6.2.5"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
 [[package]]
 name = "pyyaml"
 version = "6.0.2"
diff --git a/pyproject.toml b/pyproject.toml
index 09c7c90d..deaf4c90 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ python = "^3.9"
 httpx = { version = ">=0.26,<0.29", extras = ["http2"] }
 pydantic = ">=1.10,<3"
 pyjwt = "^2.10.1"
+pytest-mock = "^3.14.0"
 
 [tool.poetry.dev-dependencies]
 pytest = "^8.3.5"
diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py
index 4069991b..a99a1133 100644
--- a/supabase_auth/_async/gotrue_client.py
+++ b/supabase_auth/_async/gotrue_client.py
@@ -4,7 +4,7 @@
 from functools import partial
 from json import loads
 from time import time
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict
 from urllib.parse import parse_qs, urlencode, urlparse
 from uuid import uuid4
 
@@ -28,6 +28,7 @@
     decode_jwt,
     generate_pkce_challenge,
     generate_pkce_verifier,
+    get_algorithm,
     model_dump,
     model_dump_json,
     model_validate,
@@ -36,6 +37,7 @@
     parse_link_identity_response,
     parse_sso_response,
     parse_user_response,
+    validate_exp,
 )
 from ..http_clients import AsyncClient
 from ..timer import Timer
@@ -53,9 +55,12 @@
     AuthMFAVerifyResponse,
     AuthOtpResponse,
     AuthResponse,
+    ClaimsResponse,
     CodeExchangeParams,
     DecodedJWTDict,
     IdentitiesResponse,
+    JWTHeader,
+    JWTPayload,
     MFAChallengeAndVerifyParams,
     MFAChallengeParams,
     MFAEnrollParams,
@@ -1158,13 +1163,13 @@ async def exchange_code_for_session(self, params: CodeExchangeParams):
 
     async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK:
         # try fetching from the suplied keys.
-        jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None)
+        jwk = next((jwk for jwk in jwks.keys if jwk["kid"] == kid), None)
 
         if jwk:
             return jwk
 
         # try fetching from the cache.
-        jwk = next((jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None)
+        jwk = next((jwk for jwk in self._jwks.keys if jwk["kid"] == kid), None)
         if jwk:
             return jwk
 
@@ -1184,12 +1189,34 @@ async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK:
 
         raise AuthInvalidJwtError("JWT has no valid kid")
 
-    async def get_claims(self):
-        pass
+    async def get_claims(self, jwt: Optional[str] = None, jwks: Optional[JWKS] = None) -> Optional[ClaimsResponse]:
+        token = jwt
+        if not token:
+            session = await self.get_session()
+            if not session:
+                return None
+            
+            token = session.access_token
+
+        decoded_jwt = decode_jwt(token)
+        payload = decoded_jwt["payload"]
+        header = decoded_jwt["header"]
+        signature = decoded_jwt["signature"]
+
+        validate_exp(payload["exp"])
+
+        # if symetric algorithm, fallback to get_user
+        if 'kid' not in header or header["alg"] == 'HS256':
+            await self.get_user(token)
+            return ClaimsResponse(claims=payload, headers=header, signature=signature)
+
+        algorithm = get_algorithm(header["alg"])
+        signing_key = await self._fetch_jwks(header["kid"], jwks)
+        algorithm.prepare_key
 
 
 def parse_jwks(response: Any) -> JWKS:
     if "keys" not in response or len(response.keys) == 0:
         raise AuthInvalidJwtError("JWKS is empty")
 
-    return JWKS(keys=response.keys)
+    return JWKS(keys=response["keys"])
diff --git a/supabase_auth/constants.py b/supabase_auth/constants.py
index d16c0f31..5b0d1abc 100644
--- a/supabase_auth/constants.py
+++ b/supabase_auth/constants.py
@@ -21,3 +21,4 @@
         "name": "2024-01-01",
     },
 }
+BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$"
\ No newline at end of file
diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py
index 57cecf31..e46214c3 100644
--- a/supabase_auth/helpers.py
+++ b/supabase_auth/helpers.py
@@ -16,7 +16,7 @@
 from httpx import HTTPStatusError, Response
 from pydantic import BaseModel
 
-from .constants import API_VERSION_HEADER_NAME, API_VERSIONS
+from .constants import API_VERSION_HEADER_NAME, API_VERSIONS, BASE64URL_REGEX
 from .errors import (
     AuthApiError,
     AuthError,
@@ -40,7 +40,6 @@
 )
 
 TBaseModel = TypeVar("TBaseModel", bound=BaseModel)
-BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$"
 
 
 def model_validate(model: Type[TBaseModel], contents) -> TBaseModel:
diff --git a/supabase_auth/types.py b/supabase_auth/types.py
index 5425173a..3d671cde 100644
--- a/supabase_auth/types.py
+++ b/supabase_auth/types.py
@@ -806,11 +806,11 @@ class RequiredClaims(TypedDict):
     session_id: str
 
 
-class JWTPayload(RequiredClaims, TypedDict, total=False):
+class JWTPayload(RequiredClaims, total=False):
     pass
 
 
-class JWK(TypedDict, total=False):
+class JWK(TypedDict):
     kty: Literal["RSA", "EC", "oct"]
     key_ops: List[str]
     alg: Optional[str]
@@ -821,6 +821,11 @@ class JWKS(TypedDict):
     keys: List[JWK]
 
 
+class ClaimsResponse(TypedDict):
+    claims: JWTPayload
+    headers: JWTHeader
+    signature: bytes
+
 for model in [
     AMREntry,
     AuthResponse,
diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py
new file mode 100644
index 00000000..6d13addd
--- /dev/null
+++ b/tests/_async/test_gotrue.py
@@ -0,0 +1,24 @@
+import pytest
+from pytest_mock import mocker
+
+from .utils import mock_user_credentials
+
+from .clients import auth_client
+
+async def test_get_claims_returns_none_when_session_is_none():
+  claims = await auth_client().get_claims()
+  assert claims is None
+
+async def test_get_claims_calls_get_user_if_symmetric_jwt(mocker):
+  client = auth_client()
+  credentials = mock_user_credentials()
+  spy = mocker.spy(client, 'get_user')
+
+  response = await client.sign_up(credentials)
+  user = response.user
+  assert user is not None
+
+  response = await client.get_claims()
+  assert response["claims"]["email"] == user.email
+  spy.assert_called_once()
+  
\ No newline at end of file

From 8683823423dcf0c896eabb5483e4e78d6c9dd56b Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Wed, 12 Mar 2025 11:06:46 +0100
Subject: [PATCH 07/15] tests for asymmetric signature

---
 supabase_auth/_async/gotrue_client.py | 72 +++++++++++++++---------
 supabase_auth/_sync/gotrue_client.py  | 79 +++++++++++++++++++++------
 supabase_auth/constants.py            |  2 +-
 supabase_auth/helpers.py              | 11 +---
 supabase_auth/types.py                | 12 +---
 tests/_async/clients.py               | 12 ++++
 tests/_async/test_gotrue.py           | 48 ++++++++++------
 tests/_sync/clients.py                | 12 ++++
 tests/_sync/test_gotrue.py            | 41 ++++++++++++++
 9 files changed, 207 insertions(+), 82 deletions(-)
 create mode 100644 tests/_sync/test_gotrue.py

diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py
index a99a1133..2fa8154c 100644
--- a/supabase_auth/_async/gotrue_client.py
+++ b/supabase_auth/_async/gotrue_client.py
@@ -4,10 +4,12 @@
 from functools import partial
 from json import loads
 from time import time
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from urllib.parse import parse_qs, urlencode, urlparse
 from uuid import uuid4
 
+from jwt import get_algorithm_by_name
+
 from ..constants import (
     DEFAULT_HEADERS,
     EXPIRY_MARGIN,
@@ -28,7 +30,6 @@
     decode_jwt,
     generate_pkce_challenge,
     generate_pkce_verifier,
-    get_algorithm,
     model_dump,
     model_dump_json,
     model_validate,
@@ -42,8 +43,6 @@
 from ..http_clients import AsyncClient
 from ..timer import Timer
 from ..types import (
-    JWK,
-    JWKS,
     AuthChangeEvent,
     AuthenticatorAssuranceLevels,
     AuthFlowType,
@@ -59,8 +58,6 @@
     CodeExchangeParams,
     DecodedJWTDict,
     IdentitiesResponse,
-    JWTHeader,
-    JWTPayload,
     MFAChallengeAndVerifyParams,
     MFAChallengeParams,
     MFAEnrollParams,
@@ -114,7 +111,7 @@ def __init__(
             verify=verify,
             proxy=proxy,
         )
-        self._jwks: JWKS = {}
+        self._jwks = {"keys": []}
         self._storage_key = storage_key or STORAGE_KEY
         self._auto_refresh_token = auto_refresh_token
         self._persist_session = persist_session
@@ -1161,26 +1158,32 @@ async def exchange_code_for_session(self, params: CodeExchangeParams):
             self._notify_all_subscribers("SIGNED_IN", response.session)
         return response
 
-    async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK:
+    async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
+        jwk: Dict[str, Any] = {}
+
         # try fetching from the suplied keys.
-        jwk = next((jwk for jwk in jwks.keys if jwk["kid"] == kid), None)
+        jwk = next((jwk for jwk in jwks.get("keys", []) if jwk.get("kid") == kid), None)
 
         if jwk:
             return jwk
 
-        # try fetching from the cache.
-        jwk = next((jwk for jwk in self._jwks.keys if jwk["kid"] == kid), None)
-        if jwk:
-            return jwk
+        if self._jwks:
+            # try fetching from the cache.
+            jwk = next(
+                (jwk for jwk in self._jwks.get("keys", []) if jwk.get("kid") == kid),
+                None,
+            )
+            if jwk:
+                return jwk
 
         # jwk isn't cached in memory so we need to fetch it from the well-known endpoint
         response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks)
-        if response.jwks:
-            self._jwks = response.jwks
+        if response:
+            self._jwks = response
 
             # find the signing key
             jwk = next(
-                (jwk for jwk in response.jwks["keys"] if jwk["kid"] == kid), None
+                (jwk for jwk in response.get("keys", []) if jwk.get("kid") == kid), None
             )
             if not jwk:
                 raise AuthInvalidJwtError("No matching signing key found in JWKS")
@@ -1189,13 +1192,15 @@ async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK:
 
         raise AuthInvalidJwtError("JWT has no valid kid")
 
-    async def get_claims(self, jwt: Optional[str] = None, jwks: Optional[JWKS] = None) -> Optional[ClaimsResponse]:
+    async def get_claims(
+        self, jwt: Optional[str] = None, jwks: Optional[Dict[str, list]] = None
+    ) -> Optional[ClaimsResponse]:
         token = jwt
         if not token:
             session = await self.get_session()
             if not session:
                 return None
-            
+
             token = session.access_token
 
         decoded_jwt = decode_jwt(token)
@@ -1203,20 +1208,35 @@ async def get_claims(self, jwt: Optional[str] = None, jwks: Optional[JWKS] = Non
         header = decoded_jwt["header"]
         signature = decoded_jwt["signature"]
 
+        raw_header = decoded_jwt["raw"]["header"]
+        raw_payload = decoded_jwt["raw"]["payload"]
+
         validate_exp(payload["exp"])
 
-        # if symetric algorithm, fallback to get_user
-        if 'kid' not in header or header["alg"] == 'HS256':
+        # if symmetric algorithm, fallback to get_user
+        if "kid" not in header or header["alg"] == "HS256":
             await self.get_user(token)
             return ClaimsResponse(claims=payload, headers=header, signature=signature)
 
-        algorithm = get_algorithm(header["alg"])
-        signing_key = await self._fetch_jwks(header["kid"], jwks)
-        algorithm.prepare_key
+        algorithm = get_algorithm_by_name(header["alg"])
+        signing_key = algorithm.from_jwk(
+            await self._fetch_jwks(header["kid"], jwks or {})
+        )
+
+        # verify the signature
+        is_valid = algorithm.verify(
+            msg=f"{raw_header}.{raw_payload}".encode(), key=signing_key, sig=signature
+        )
+
+        if not is_valid:
+            raise AuthInvalidJwtError("Invalid JWT signature")
+
+        # If verification succeeds, decode and return claims
+        return ClaimsResponse(claims=payload, headers=header, signature=signature)
 
 
-def parse_jwks(response: Any) -> JWKS:
-    if "keys" not in response or len(response.keys) == 0:
+def parse_jwks(response: Any) -> Dict[str, list]:
+    if "keys" not in response or len(response["keys"]) == 0:
         raise AuthInvalidJwtError("JWKS is empty")
 
-    return JWKS(keys=response["keys"])
+    return response
diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py
index 6e1cbed4..41c2de5d 100644
--- a/supabase_auth/_sync/gotrue_client.py
+++ b/supabase_auth/_sync/gotrue_client.py
@@ -8,6 +8,8 @@
 from urllib.parse import parse_qs, urlencode, urlparse
 from uuid import uuid4
 
+from jwt import get_algorithm_by_name
+
 from ..constants import (
     DEFAULT_HEADERS,
     EXPIRY_MARGIN,
@@ -36,12 +38,11 @@
     parse_link_identity_response,
     parse_sso_response,
     parse_user_response,
+    validate_exp,
 )
 from ..http_clients import SyncClient
 from ..timer import Timer
 from ..types import (
-    JWK,
-    JWKS,
     AuthChangeEvent,
     AuthenticatorAssuranceLevels,
     AuthFlowType,
@@ -53,6 +54,7 @@
     AuthMFAVerifyResponse,
     AuthOtpResponse,
     AuthResponse,
+    ClaimsResponse,
     CodeExchangeParams,
     DecodedJWTDict,
     IdentitiesResponse,
@@ -109,7 +111,7 @@ def __init__(
             verify=verify,
             proxy=proxy,
         )
-        self._jwks: JWKS = {}
+        self._jwks = {"keys": []}
         self._storage_key = storage_key or STORAGE_KEY
         self._auto_refresh_token = auto_refresh_token
         self._persist_session = persist_session
@@ -1150,26 +1152,32 @@ def exchange_code_for_session(self, params: CodeExchangeParams):
             self._notify_all_subscribers("SIGNED_IN", response.session)
         return response
 
-    def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK:
+    def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
+        jwk: Dict[str, Any] = {}
+
         # try fetching from the suplied keys.
-        jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None)
+        jwk = next((jwk for jwk in jwks.get("keys", []) if jwk.get("kid") == kid), None)
 
         if jwk:
             return jwk
 
-        # try fetching from the cache.
-        jwk = next((jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None)
-        if jwk:
-            return jwk
+        if self._jwks:
+            # try fetching from the cache.
+            jwk = next(
+                (jwk for jwk in self._jwks.get("keys", []) if jwk.get("kid") == kid),
+                None,
+            )
+            if jwk:
+                return jwk
 
         # jwk isn't cached in memory so we need to fetch it from the well-known endpoint
         response = self._request("GET", ".well-known/jwks.json", xform=parse_jwks)
-        if response.jwks:
-            self._jwks = response.jwks
+        if response:
+            self._jwks = response
 
             # find the signing key
             jwk = next(
-                (jwk for jwk in response.jwks["keys"] if jwk["kid"] == kid), None
+                (jwk for jwk in response.get("keys", []) if jwk.get("kid") == kid), None
             )
             if not jwk:
                 raise AuthInvalidJwtError("No matching signing key found in JWKS")
@@ -1178,12 +1186,49 @@ def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK:
 
         raise AuthInvalidJwtError("JWT has no valid kid")
 
-    def get_claims(self):
-        pass
+    def get_claims(
+        self, jwt: Optional[str] = None, jwks: Optional[Dict[str, list]] = None
+    ) -> Optional[ClaimsResponse]:
+        token = jwt
+        if not token:
+            session = self.get_session()
+            if not session:
+                return None
+
+            token = session.access_token
+
+        decoded_jwt = decode_jwt(token)
+        payload = decoded_jwt["payload"]
+        header = decoded_jwt["header"]
+        signature = decoded_jwt["signature"]
+
+        raw_header = decoded_jwt["raw"]["header"]
+        raw_payload = decoded_jwt["raw"]["payload"]
+
+        validate_exp(payload["exp"])
+
+        # if symmetric algorithm, fallback to get_user
+        if "kid" not in header or header["alg"] == "HS256":
+            self.get_user(token)
+            return ClaimsResponse(claims=payload, headers=header, signature=signature)
+
+        algorithm = get_algorithm_by_name(header["alg"])
+        signing_key = algorithm.from_jwk(self._fetch_jwks(header["kid"], jwks or {}))
+
+        # verify the signature
+        is_valid = algorithm.verify(
+            msg=f"{raw_header}.{raw_payload}".encode(), key=signing_key, sig=signature
+        )
+
+        if not is_valid:
+            raise AuthInvalidJwtError("Invalid JWT signature")
+
+        # If verification succeeds, decode and return claims
+        return ClaimsResponse(claims=payload, headers=header, signature=signature)
 
 
-def parse_jwks(response: Any) -> JWKS:
-    if "keys" not in response or len(response.keys) == 0:
+def parse_jwks(response: Any) -> Dict[str, list]:
+    if "keys" not in response or len(response["keys"]) == 0:
         raise AuthInvalidJwtError("JWKS is empty")
 
-    return JWKS(keys=response.keys)
+    return response
diff --git a/supabase_auth/constants.py b/supabase_auth/constants.py
index 5b0d1abc..671510e5 100644
--- a/supabase_auth/constants.py
+++ b/supabase_auth/constants.py
@@ -21,4 +21,4 @@
         "name": "2024-01-01",
     },
 }
-BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$"
\ No newline at end of file
+BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$"
diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py
index e46214c3..65c68ad4 100644
--- a/supabase_auth/helpers.py
+++ b/supabase_auth/helpers.py
@@ -8,11 +8,9 @@
 from base64 import urlsafe_b64decode
 from datetime import datetime
 from json import loads
-from typing import Any, Dict, Literal, Optional, Type, TypedDict, TypeVar, cast
+from typing import Any, Dict, Optional, Type, TypedDict, TypeVar, cast
 from urllib.parse import urlparse
 
-import jwt
-import jwt.algorithms
 from httpx import HTTPStatusError, Response
 from pydantic import BaseModel
 
@@ -311,10 +309,3 @@ def validate_exp(exp: int) -> None:
     time_now = datetime.now().timestamp()
     if exp <= time_now:
         raise AuthInvalidJwtError("JWT has expired")
-
-
-def get_algorithm(alg: Literal["RS256", "ES256"]) -> jwt.algorithms.Algorithm:
-    if alg == "RS256":
-        return jwt.algorithms.RSAAlgorithm
-    elif alg == "ES256":
-        return jwt.algorithms.ECAlgorithm
diff --git a/supabase_auth/types.py b/supabase_auth/types.py
index 3d671cde..4ef3c1ac 100644
--- a/supabase_auth/types.py
+++ b/supabase_auth/types.py
@@ -810,22 +810,12 @@ class JWTPayload(RequiredClaims, total=False):
     pass
 
 
-class JWK(TypedDict):
-    kty: Literal["RSA", "EC", "oct"]
-    key_ops: List[str]
-    alg: Optional[str]
-    kid: Optional[str]
-
-
-class JWKS(TypedDict):
-    keys: List[JWK]
-
-
 class ClaimsResponse(TypedDict):
     claims: JWTPayload
     headers: JWTHeader
     signature: bytes
 
+
 for model in [
     AMREntry,
     AuthResponse,
diff --git a/tests/_async/clients.py b/tests/_async/clients.py
index 3babf2ee..03356714 100644
--- a/tests/_async/clients.py
+++ b/tests/_async/clients.py
@@ -5,6 +5,7 @@
 SIGNUP_ENABLED_AUTO_CONFIRM_OFF_PORT = 9999
 SIGNUP_ENABLED_AUTO_CONFIRM_ON_PORT = 9998
 SIGNUP_DISABLED_AUTO_CONFIRM_OFF_PORT = 9997
+SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON_PORT = 9996
 
 GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_OFF = (
     f"http://localhost:{SIGNUP_ENABLED_AUTO_CONFIRM_OFF_PORT}"
@@ -12,6 +13,9 @@
 GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON = (
     f"http://localhost:{SIGNUP_ENABLED_AUTO_CONFIRM_ON_PORT}"
 )
+GOTRUE_URL_SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON = (
+    f"http://localhost:{SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON_PORT}"
+)
 GOTRUE_URL_SIGNUP_DISABLED_AUTO_CONFIRM_OFF = (
     f"http://localhost:{SIGNUP_DISABLED_AUTO_CONFIRM_OFF_PORT}"
 )
@@ -43,6 +47,14 @@ def auth_client_with_session():
     )
 
 
+def auth_client_with_asymmetric_session() -> AsyncGoTrueClient:
+    return AsyncGoTrueClient(
+        url=GOTRUE_URL_SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON,
+        auto_refresh_token=False,
+        persist_session=False,
+    )
+
+
 def auth_subscription_client():
     return AsyncGoTrueClient(
         url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON,
diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py
index 6d13addd..05adf67a 100644
--- a/tests/_async/test_gotrue.py
+++ b/tests/_async/test_gotrue.py
@@ -1,24 +1,38 @@
-import pytest
-from pytest_mock import mocker
+import unittest
 
+from .clients import auth_client, auth_client_with_asymmetric_session
 from .utils import mock_user_credentials
 
-from .clients import auth_client
 
 async def test_get_claims_returns_none_when_session_is_none():
-  claims = await auth_client().get_claims()
-  assert claims is None
+    claims = await auth_client().get_claims()
+    assert claims is None
+
 
 async def test_get_claims_calls_get_user_if_symmetric_jwt(mocker):
-  client = auth_client()
-  credentials = mock_user_credentials()
-  spy = mocker.spy(client, 'get_user')
-
-  response = await client.sign_up(credentials)
-  user = response.user
-  assert user is not None
-
-  response = await client.get_claims()
-  assert response["claims"]["email"] == user.email
-  spy.assert_called_once()
-  
\ No newline at end of file
+    client = auth_client()
+    spy = mocker.spy(client, "get_user")
+
+    user = (await client.sign_up(mock_user_credentials())).user
+    assert user is not None
+
+    claims = (await client.get_claims())["claims"]
+    assert claims["email"] == user.email
+    spy.assert_called_once()
+
+
+async def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker):
+    client = auth_client_with_asymmetric_session()
+
+    user = (await client.sign_up(mock_user_credentials())).user
+    assert user is not None
+
+    spy = mocker.spy(client, "_request")
+
+    claims = (await client.get_claims())["claims"]
+    assert claims["email"] == user.email
+
+    spy.assert_called_once()
+    spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY)
+
+    assert len(spy.spy_return.get("keys")) > 0
diff --git a/tests/_sync/clients.py b/tests/_sync/clients.py
index 50086c2a..3fee59d1 100644
--- a/tests/_sync/clients.py
+++ b/tests/_sync/clients.py
@@ -5,6 +5,7 @@
 SIGNUP_ENABLED_AUTO_CONFIRM_OFF_PORT = 9999
 SIGNUP_ENABLED_AUTO_CONFIRM_ON_PORT = 9998
 SIGNUP_DISABLED_AUTO_CONFIRM_OFF_PORT = 9997
+SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON_PORT = 9996
 
 GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_OFF = (
     f"http://localhost:{SIGNUP_ENABLED_AUTO_CONFIRM_OFF_PORT}"
@@ -12,6 +13,9 @@
 GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON = (
     f"http://localhost:{SIGNUP_ENABLED_AUTO_CONFIRM_ON_PORT}"
 )
+GOTRUE_URL_SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON = (
+    f"http://localhost:{SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON_PORT}"
+)
 GOTRUE_URL_SIGNUP_DISABLED_AUTO_CONFIRM_OFF = (
     f"http://localhost:{SIGNUP_DISABLED_AUTO_CONFIRM_OFF_PORT}"
 )
@@ -43,6 +47,14 @@ def auth_client_with_session():
     )
 
 
+def auth_client_with_asymmetric_session() -> SyncGoTrueClient:
+    return SyncGoTrueClient(
+        url=GOTRUE_URL_SIGNUP_ENABLED_ASYMMETRIC_AUTO_CONFIRM_ON,
+        auto_refresh_token=False,
+        persist_session=False,
+    )
+
+
 def auth_subscription_client():
     return SyncGoTrueClient(
         url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON,
diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py
new file mode 100644
index 00000000..c0b576ad
--- /dev/null
+++ b/tests/_sync/test_gotrue.py
@@ -0,0 +1,41 @@
+import unittest
+import pytest
+from pytest_mock import mocker
+
+from .utils import mock_user_credentials
+
+from .clients import auth_client, auth_client_with_asymmetric_session
+
+def test_get_claims_returns_none_when_session_is_none():
+  claims = auth_client().get_claims()
+  assert claims is None
+
+def test_get_claims_calls_get_user_if_symmetric_jwt(mocker):
+  client = auth_client()
+  spy = mocker.spy(client, 'get_user')
+
+  user = (client.sign_up(mock_user_credentials())).user
+  assert user is not None
+
+  claims = (client.get_claims())["claims"]
+  assert claims["email"] == user.email
+  spy.assert_called_once()
+  
+
+def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker):
+  client = auth_client_with_asymmetric_session()
+
+  user = (client.sign_up(mock_user_credentials())).user
+  assert user is not None
+
+  spy = mocker.spy(client, "_request")
+
+  claims = (client.get_claims())["claims"]
+  assert claims["email"] == user.email
+
+  spy.assert_called_once()
+  spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY)
+
+  assert len(spy.spy_return.get("keys")) > 0
+
+  
\ No newline at end of file

From 3d92584bb15f57017a798f3c3b6e88e5681bec4f Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Wed, 12 Mar 2025 11:20:23 +0100
Subject: [PATCH 08/15] add typed dicts for convenience

---
 supabase_auth/_async/gotrue_client.py | 38 +++++++++++---------
 supabase_auth/_sync/gotrue_client.py  | 52 ++++++++++++++++-----------
 supabase_auth/types.py                | 11 ++++++
 tests/_sync/test_gotrue.py            | 45 +++++++++++------------
 tests/_sync/test_gotrue_admin_api.py  | 34 +++++++++---------
 5 files changed, 101 insertions(+), 79 deletions(-)

diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py
index 2fa8154c..452ee612 100644
--- a/supabase_auth/_async/gotrue_client.py
+++ b/supabase_auth/_async/gotrue_client.py
@@ -43,6 +43,7 @@
 from ..http_clients import AsyncClient
 from ..timer import Timer
 from ..types import (
+    JWK,
     AuthChangeEvent,
     AuthenticatorAssuranceLevels,
     AuthFlowType,
@@ -58,6 +59,7 @@
     CodeExchangeParams,
     DecodedJWTDict,
     IdentitiesResponse,
+    JWKSet,
     MFAChallengeAndVerifyParams,
     MFAChallengeParams,
     MFAEnrollParams,
@@ -111,7 +113,7 @@ def __init__(
             verify=verify,
             proxy=proxy,
         )
-        self._jwks = {"keys": []}
+        self._jwks: JWKSet = {"keys": []}
         self._storage_key = storage_key or STORAGE_KEY
         self._auto_refresh_token = auto_refresh_token
         self._persist_session = persist_session
@@ -1158,11 +1160,11 @@ async def exchange_code_for_session(self, params: CodeExchangeParams):
             self._notify_all_subscribers("SIGNED_IN", response.session)
         return response
 
-    async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
-        jwk: Dict[str, Any] = {}
+    async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
+        jwk: Optional[JWK] = None
 
         # try fetching from the suplied keys.
-        jwk = next((jwk for jwk in jwks.get("keys", []) if jwk.get("kid") == kid), None)
+        jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None)
 
         if jwk:
             return jwk
@@ -1170,7 +1172,7 @@ async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
         if self._jwks:
             # try fetching from the cache.
             jwk = next(
-                (jwk for jwk in self._jwks.get("keys", []) if jwk.get("kid") == kid),
+                (jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid),
                 None,
             )
             if jwk:
@@ -1182,9 +1184,7 @@ async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
             self._jwks = response
 
             # find the signing key
-            jwk = next(
-                (jwk for jwk in response.get("keys", []) if jwk.get("kid") == kid), None
-            )
+            jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None)
             if not jwk:
                 raise AuthInvalidJwtError("No matching signing key found in JWKS")
 
@@ -1193,7 +1193,7 @@ async def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
         raise AuthInvalidJwtError("JWT has no valid kid")
 
     async def get_claims(
-        self, jwt: Optional[str] = None, jwks: Optional[Dict[str, list]] = None
+        self, jwt: Optional[str] = None, jwks: Optional[JWKSet] = None
     ) -> Optional[ClaimsResponse]:
         token = jwt
         if not token:
@@ -1204,12 +1204,16 @@ async def get_claims(
             token = session.access_token
 
         decoded_jwt = decode_jwt(token)
-        payload = decoded_jwt["payload"]
-        header = decoded_jwt["header"]
-        signature = decoded_jwt["signature"]
 
-        raw_header = decoded_jwt["raw"]["header"]
-        raw_payload = decoded_jwt["raw"]["payload"]
+        payload, header, signature = (
+            decoded_jwt["payload"],
+            decoded_jwt["header"],
+            decoded_jwt["signature"],
+        )
+        raw_header, raw_payload = (
+            decoded_jwt["raw"]["header"],
+            decoded_jwt["raw"]["payload"],
+        )
 
         validate_exp(payload["exp"])
 
@@ -1220,7 +1224,7 @@ async def get_claims(
 
         algorithm = get_algorithm_by_name(header["alg"])
         signing_key = algorithm.from_jwk(
-            await self._fetch_jwks(header["kid"], jwks or {})
+            await self._fetch_jwks(header["kid"], jwks or {"keys": []})
         )
 
         # verify the signature
@@ -1235,8 +1239,8 @@ async def get_claims(
         return ClaimsResponse(claims=payload, headers=header, signature=signature)
 
 
-def parse_jwks(response: Any) -> Dict[str, list]:
+def parse_jwks(response: Any) -> JWKSet:
     if "keys" not in response or len(response["keys"]) == 0:
         raise AuthInvalidJwtError("JWKS is empty")
 
-    return response
+    return {"keys": response["keys"]}
diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py
index 41c2de5d..9a4404f8 100644
--- a/supabase_auth/_sync/gotrue_client.py
+++ b/supabase_auth/_sync/gotrue_client.py
@@ -43,6 +43,7 @@
 from ..http_clients import SyncClient
 from ..timer import Timer
 from ..types import (
+    JWK,
     AuthChangeEvent,
     AuthenticatorAssuranceLevels,
     AuthFlowType,
@@ -58,6 +59,7 @@
     CodeExchangeParams,
     DecodedJWTDict,
     IdentitiesResponse,
+    JWKSet,
     MFAChallengeAndVerifyParams,
     MFAChallengeParams,
     MFAEnrollParams,
@@ -111,7 +113,7 @@ def __init__(
             verify=verify,
             proxy=proxy,
         )
-        self._jwks = {"keys": []}
+        self._jwks: JWKSet = {"keys": []}
         self._storage_key = storage_key or STORAGE_KEY
         self._auto_refresh_token = auto_refresh_token
         self._persist_session = persist_session
@@ -421,7 +423,9 @@ def sign_in_with_oauth(
         )
         return OAuthResponse(provider=provider, url=url_with_qs)
 
-    def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse:
+    def link_identity(
+        self, credentials: SignInWithOAuthCredentials
+    ) -> OAuthResponse:
         provider = credentials.get("provider")
         options = credentials.get("options", {})
         redirect_to = options.get("redirect_to")
@@ -704,7 +708,9 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
         self._notify_all_subscribers("TOKEN_REFRESHED", session)
         return AuthResponse(session=session, user=response.user)
 
-    def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse:
+    def refresh_session(
+        self, refresh_token: Optional[str] = None
+    ) -> AuthResponse:
         """
         Returns a new session, regardless of expiry status.
 
@@ -1113,7 +1119,9 @@ def _get_url_for_provider(
         if self._flow_type == "pkce":
             code_verifier = generate_pkce_verifier()
             code_challenge = generate_pkce_challenge(code_verifier)
-            self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier)
+            self._storage.set_item(
+                f"{self._storage_key}-code-verifier", code_verifier
+            )
             code_challenge_method = (
                 "plain" if code_verifier == code_challenge else "s256"
             )
@@ -1152,11 +1160,11 @@ def exchange_code_for_session(self, params: CodeExchangeParams):
             self._notify_all_subscribers("SIGNED_IN", response.session)
         return response
 
-    def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
-        jwk: Dict[str, Any] = {}
+    def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
+        jwk: Optional[JWK] = None
 
         # try fetching from the suplied keys.
-        jwk = next((jwk for jwk in jwks.get("keys", []) if jwk.get("kid") == kid), None)
+        jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None)
 
         if jwk:
             return jwk
@@ -1164,7 +1172,7 @@ def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
         if self._jwks:
             # try fetching from the cache.
             jwk = next(
-                (jwk for jwk in self._jwks.get("keys", []) if jwk.get("kid") == kid),
+                (jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid),
                 None,
             )
             if jwk:
@@ -1176,9 +1184,7 @@ def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
             self._jwks = response
 
             # find the signing key
-            jwk = next(
-                (jwk for jwk in response.get("keys", []) if jwk.get("kid") == kid), None
-            )
+            jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None)
             if not jwk:
                 raise AuthInvalidJwtError("No matching signing key found in JWKS")
 
@@ -1187,7 +1193,7 @@ def _fetch_jwks(self, kid: str, jwks: Dict[str, list]) -> Dict[str, Any]:
         raise AuthInvalidJwtError("JWT has no valid kid")
 
     def get_claims(
-        self, jwt: Optional[str] = None, jwks: Optional[Dict[str, list]] = None
+        self, jwt: Optional[str] = None, jwks: Optional[JWKSet] = None
     ) -> Optional[ClaimsResponse]:
         token = jwt
         if not token:
@@ -1198,12 +1204,16 @@ def get_claims(
             token = session.access_token
 
         decoded_jwt = decode_jwt(token)
-        payload = decoded_jwt["payload"]
-        header = decoded_jwt["header"]
-        signature = decoded_jwt["signature"]
 
-        raw_header = decoded_jwt["raw"]["header"]
-        raw_payload = decoded_jwt["raw"]["payload"]
+        payload, header, signature = (
+            decoded_jwt["payload"],
+            decoded_jwt["header"],
+            decoded_jwt["signature"],
+        )
+        raw_header, raw_payload = (
+            decoded_jwt["raw"]["header"],
+            decoded_jwt["raw"]["payload"],
+        )
 
         validate_exp(payload["exp"])
 
@@ -1213,7 +1223,9 @@ def get_claims(
             return ClaimsResponse(claims=payload, headers=header, signature=signature)
 
         algorithm = get_algorithm_by_name(header["alg"])
-        signing_key = algorithm.from_jwk(self._fetch_jwks(header["kid"], jwks or {}))
+        signing_key = algorithm.from_jwk(
+            self._fetch_jwks(header["kid"], jwks or {"keys": []})
+        )
 
         # verify the signature
         is_valid = algorithm.verify(
@@ -1227,8 +1239,8 @@ def get_claims(
         return ClaimsResponse(claims=payload, headers=header, signature=signature)
 
 
-def parse_jwks(response: Any) -> Dict[str, list]:
+def parse_jwks(response: Any) -> JWKSet:
     if "keys" not in response or len(response["keys"]) == 0:
         raise AuthInvalidJwtError("JWKS is empty")
 
-    return response
+    return {"keys": response["keys"]}
diff --git a/supabase_auth/types.py b/supabase_auth/types.py
index 4ef3c1ac..991a27b4 100644
--- a/supabase_auth/types.py
+++ b/supabase_auth/types.py
@@ -816,6 +816,17 @@ class ClaimsResponse(TypedDict):
     signature: bytes
 
 
+class JWK(TypedDict, total=False):
+    kty: Literal["RSA", "EC", "oct"]
+    key_ops: List[str]
+    alg: Optional[str]
+    kid: Optional[str]
+
+
+class JWKSet(TypedDict):
+    keys: List[JWK]
+
+
 for model in [
     AMREntry,
     AuthResponse,
diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py
index c0b576ad..0f1082b4 100644
--- a/tests/_sync/test_gotrue.py
+++ b/tests/_sync/test_gotrue.py
@@ -1,41 +1,38 @@
 import unittest
-import pytest
-from pytest_mock import mocker
 
+from .clients import auth_client, auth_client_with_asymmetric_session
 from .utils import mock_user_credentials
 
-from .clients import auth_client, auth_client_with_asymmetric_session
 
 def test_get_claims_returns_none_when_session_is_none():
-  claims = auth_client().get_claims()
-  assert claims is None
+    claims = auth_client().get_claims()
+    assert claims is None
+
 
 def test_get_claims_calls_get_user_if_symmetric_jwt(mocker):
-  client = auth_client()
-  spy = mocker.spy(client, 'get_user')
+    client = auth_client()
+    spy = mocker.spy(client, "get_user")
 
-  user = (client.sign_up(mock_user_credentials())).user
-  assert user is not None
+    user = (client.sign_up(mock_user_credentials())).user
+    assert user is not None
 
-  claims = (client.get_claims())["claims"]
-  assert claims["email"] == user.email
-  spy.assert_called_once()
-  
+    claims = (client.get_claims())["claims"]
+    assert claims["email"] == user.email
+    spy.assert_called_once()
 
-def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker):
-  client = auth_client_with_asymmetric_session()
 
-  user = (client.sign_up(mock_user_credentials())).user
-  assert user is not None
+def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker):
+    client = auth_client_with_asymmetric_session()
 
-  spy = mocker.spy(client, "_request")
+    user = (client.sign_up(mock_user_credentials())).user
+    assert user is not None
 
-  claims = (client.get_claims())["claims"]
-  assert claims["email"] == user.email
+    spy = mocker.spy(client, "_request")
 
-  spy.assert_called_once()
-  spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY)
+    claims = (client.get_claims())["claims"]
+    assert claims["email"] == user.email
 
-  assert len(spy.spy_return.get("keys")) > 0
+    spy.assert_called_once()
+    spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY)
 
-  
\ No newline at end of file
+    assert len(spy.spy_return.get("keys")) > 0
diff --git a/tests/_sync/test_gotrue_admin_api.py b/tests/_sync/test_gotrue_admin_api.py
index dd5fa523..2df5352a 100644
--- a/tests/_sync/test_gotrue_admin_api.py
+++ b/tests/_sync/test_gotrue_admin_api.py
@@ -156,13 +156,11 @@ def test_modify_confirm_email_using_update_user_by_id():
 
 def test_invalid_credential_sign_in_with_phone():
     try:
-        response = (
-            client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
-                {
-                    "phone": "+123456789",
-                    "password": "strong_pwd",
-                }
-            )
+        response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
+            {
+                "phone": "+123456789",
+                "password": "strong_pwd",
+            }
         )
     except AuthApiError as e:
         assert e.to_dict()
@@ -170,13 +168,11 @@ def test_invalid_credential_sign_in_with_phone():
 
 def test_invalid_credential_sign_in_with_email():
     try:
-        response = (
-            client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
-                {
-                    "email": "unknown_user@unknowndomain.com",
-                    "password": "strong_pwd",
-                }
-            )
+        response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
+            {
+                "email": "unknown_user@unknowndomain.com",
+                "password": "strong_pwd",
+            }
         )
     except AuthApiError as e:
         assert e.to_dict()
@@ -390,10 +386,12 @@ def test_sign_in_with_sso():
 
 
 def test_sign_in_with_oauth():
-    assert client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth(
-        {
-            "provider": "google",
-        }
+    assert (
+        client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth(
+            {
+                "provider": "google",
+            }
+        )
     )
 
 

From 46c4c373a1b608ea1f4801c25f812fa8588a11c4 Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Wed, 12 Mar 2025 11:22:01 +0100
Subject: [PATCH 09/15] poetry lock

---
 poetry.lock | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/poetry.lock b/poetry.lock
index 1020f882..8d1c337b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1481,4 +1481,4 @@ files = [
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "b43d599434e02d17b2037adbe56a99a89a948127d180acdd27139e748047c798"
+content-hash = "f6e106fec6347fbb9991f6587f507e8ff984db84afd392cd5ee5d5ab682f833b"

From c001f79b22c9d4c60862cf8d6397aaeee850dda6 Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Wed, 12 Mar 2025 11:55:00 +0100
Subject: [PATCH 10/15] fix code format

---
 supabase_auth/_sync/gotrue_client.py | 12 +++-------
 tests/_sync/test_gotrue_admin_api.py | 34 +++++++++++++++-------------
 2 files changed, 21 insertions(+), 25 deletions(-)

diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py
index 9a4404f8..e0535f08 100644
--- a/supabase_auth/_sync/gotrue_client.py
+++ b/supabase_auth/_sync/gotrue_client.py
@@ -423,9 +423,7 @@ def sign_in_with_oauth(
         )
         return OAuthResponse(provider=provider, url=url_with_qs)
 
-    def link_identity(
-        self, credentials: SignInWithOAuthCredentials
-    ) -> OAuthResponse:
+    def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse:
         provider = credentials.get("provider")
         options = credentials.get("options", {})
         redirect_to = options.get("redirect_to")
@@ -708,9 +706,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
         self._notify_all_subscribers("TOKEN_REFRESHED", session)
         return AuthResponse(session=session, user=response.user)
 
-    def refresh_session(
-        self, refresh_token: Optional[str] = None
-    ) -> AuthResponse:
+    def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse:
         """
         Returns a new session, regardless of expiry status.
 
@@ -1119,9 +1115,7 @@ def _get_url_for_provider(
         if self._flow_type == "pkce":
             code_verifier = generate_pkce_verifier()
             code_challenge = generate_pkce_challenge(code_verifier)
-            self._storage.set_item(
-                f"{self._storage_key}-code-verifier", code_verifier
-            )
+            self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier)
             code_challenge_method = (
                 "plain" if code_verifier == code_challenge else "s256"
             )
diff --git a/tests/_sync/test_gotrue_admin_api.py b/tests/_sync/test_gotrue_admin_api.py
index 2df5352a..dd5fa523 100644
--- a/tests/_sync/test_gotrue_admin_api.py
+++ b/tests/_sync/test_gotrue_admin_api.py
@@ -156,11 +156,13 @@ def test_modify_confirm_email_using_update_user_by_id():
 
 def test_invalid_credential_sign_in_with_phone():
     try:
-        response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
-            {
-                "phone": "+123456789",
-                "password": "strong_pwd",
-            }
+        response = (
+            client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
+                {
+                    "phone": "+123456789",
+                    "password": "strong_pwd",
+                }
+            )
         )
     except AuthApiError as e:
         assert e.to_dict()
@@ -168,11 +170,13 @@ def test_invalid_credential_sign_in_with_phone():
 
 def test_invalid_credential_sign_in_with_email():
     try:
-        response = client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
-            {
-                "email": "unknown_user@unknowndomain.com",
-                "password": "strong_pwd",
-            }
+        response = (
+            client_api_auto_confirm_off_signups_enabled_client().sign_in_with_password(
+                {
+                    "email": "unknown_user@unknowndomain.com",
+                    "password": "strong_pwd",
+                }
+            )
         )
     except AuthApiError as e:
         assert e.to_dict()
@@ -386,12 +390,10 @@ def test_sign_in_with_sso():
 
 
 def test_sign_in_with_oauth():
-    assert (
-        client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth(
-            {
-                "provider": "google",
-            }
-        )
+    assert client_api_auto_confirm_off_signups_enabled_client().sign_in_with_oauth(
+        {
+            "provider": "google",
+        }
     )
 
 

From 1d72f0740bb4cbe79b143357b948b8a71746208b Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Wed, 12 Mar 2025 14:25:42 +0100
Subject: [PATCH 11/15] improve test by asserting cache keys

---
 supabase_auth/_async/gotrue_client.py | 10 ++--------
 supabase_auth/_sync/gotrue_client.py  | 10 ++--------
 supabase_auth/helpers.py              |  8 ++++++++
 tests/_async/test_gotrue.py           |  5 ++++-
 tests/_sync/test_gotrue.py            |  5 ++++-
 5 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py
index 452ee612..2669817d 100644
--- a/supabase_auth/_async/gotrue_client.py
+++ b/supabase_auth/_async/gotrue_client.py
@@ -4,7 +4,7 @@
 from functools import partial
 from json import loads
 from time import time
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 from urllib.parse import parse_qs, urlencode, urlparse
 from uuid import uuid4
 
@@ -35,6 +35,7 @@
     model_validate,
     parse_auth_otp_response,
     parse_auth_response,
+    parse_jwks,
     parse_link_identity_response,
     parse_sso_response,
     parse_user_response,
@@ -1237,10 +1238,3 @@ async def get_claims(
 
         # If verification succeeds, decode and return claims
         return ClaimsResponse(claims=payload, headers=header, signature=signature)
-
-
-def parse_jwks(response: Any) -> JWKSet:
-    if "keys" not in response or len(response["keys"]) == 0:
-        raise AuthInvalidJwtError("JWKS is empty")
-
-    return {"keys": response["keys"]}
diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py
index e0535f08..c161cc22 100644
--- a/supabase_auth/_sync/gotrue_client.py
+++ b/supabase_auth/_sync/gotrue_client.py
@@ -4,7 +4,7 @@
 from functools import partial
 from json import loads
 from time import time
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 from urllib.parse import parse_qs, urlencode, urlparse
 from uuid import uuid4
 
@@ -35,6 +35,7 @@
     model_validate,
     parse_auth_otp_response,
     parse_auth_response,
+    parse_jwks,
     parse_link_identity_response,
     parse_sso_response,
     parse_user_response,
@@ -1231,10 +1232,3 @@ def get_claims(
 
         # If verification succeeds, decode and return claims
         return ClaimsResponse(claims=payload, headers=header, signature=signature)
-
-
-def parse_jwks(response: Any) -> JWKSet:
-    if "keys" not in response or len(response["keys"]) == 0:
-        raise AuthInvalidJwtError("JWKS is empty")
-
-    return {"keys": response["keys"]}
diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py
index 65c68ad4..9aebab0d 100644
--- a/supabase_auth/helpers.py
+++ b/supabase_auth/helpers.py
@@ -28,6 +28,7 @@
     AuthResponse,
     GenerateLinkProperties,
     GenerateLinkResponse,
+    JWKSet,
     JWTHeader,
     JWTPayload,
     LinkIdentityResponse,
@@ -119,6 +120,13 @@ def parse_sso_response(data: Any) -> SSOResponse:
     return model_validate(SSOResponse, data)
 
 
+def parse_jwks(response: Any) -> JWKSet:
+    if "keys" not in response or len(response["keys"]) == 0:
+        raise AuthInvalidJwtError("JWKS is empty")
+
+    return {"keys": response["keys"]}
+
+
 def get_error_message(error: Any) -> str:
     props = ["msg", "message", "error_description", "error"]
     filter = lambda prop: (
diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py
index 05adf67a..4c39c29e 100644
--- a/tests/_async/test_gotrue.py
+++ b/tests/_async/test_gotrue.py
@@ -35,4 +35,7 @@ async def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker):
     spy.assert_called_once()
     spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY)
 
-    assert len(spy.spy_return.get("keys")) > 0
+    expected_keyid = "638c54b8-28c2-4b12-9598-ba12ef610a29"
+
+    assert len(client._jwks["keys"]) == 1
+    assert client._jwks["keys"][0]["kid"] == expected_keyid
diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py
index 0f1082b4..0ccfd032 100644
--- a/tests/_sync/test_gotrue.py
+++ b/tests/_sync/test_gotrue.py
@@ -35,4 +35,7 @@ def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker):
     spy.assert_called_once()
     spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY)
 
-    assert len(spy.spy_return.get("keys")) > 0
+    expected_keyid = "638c54b8-28c2-4b12-9598-ba12ef610a29"
+
+    assert len(client._jwks["keys"]) == 1
+    assert client._jwks["keys"][0]["kid"] == expected_keyid

From c7a19033574ca93ad77ce176b1d1d5cbd1d0bbb0 Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Fri, 14 Mar 2025 10:57:21 +0100
Subject: [PATCH 12/15] add jwks cache ttl

---
 supabase_auth/_async/gotrue_client.py | 9 ++++++++-
 supabase_auth/_sync/gotrue_client.py  | 9 ++++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py
index 2669817d..aca2d468 100644
--- a/supabase_auth/_async/gotrue_client.py
+++ b/supabase_auth/_async/gotrue_client.py
@@ -114,7 +114,11 @@ def __init__(
             verify=verify,
             proxy=proxy,
         )
+
         self._jwks: JWKSet = {"keys": []}
+        self._jwks_ttl: float = 600  # 10 minutes
+        self._jwks_cached_at: Optional[float] = None
+
         self._storage_key = storage_key or STORAGE_KEY
         self._auto_refresh_token = auto_refresh_token
         self._persist_session = persist_session
@@ -1170,7 +1174,9 @@ async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
         if jwk:
             return jwk
 
-        if self._jwks:
+        if self._jwks and (
+            self._jwks_cached_at and self._jwks_cached_at + self._jwks_ttl < time()
+        ):
             # try fetching from the cache.
             jwk = next(
                 (jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid),
@@ -1183,6 +1189,7 @@ async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
         response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks)
         if response:
             self._jwks = response
+            self._jwks_cached_at = time()
 
             # find the signing key
             jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None)
diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py
index c161cc22..0817198c 100644
--- a/supabase_auth/_sync/gotrue_client.py
+++ b/supabase_auth/_sync/gotrue_client.py
@@ -114,7 +114,11 @@ def __init__(
             verify=verify,
             proxy=proxy,
         )
+
         self._jwks: JWKSet = {"keys": []}
+        self._jwks_ttl: float = 600  # 10 minutes
+        self._jwks_cached_at: Optional[float] = None
+
         self._storage_key = storage_key or STORAGE_KEY
         self._auto_refresh_token = auto_refresh_token
         self._persist_session = persist_session
@@ -1164,7 +1168,9 @@ def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
         if jwk:
             return jwk
 
-        if self._jwks:
+        if self._jwks and (
+            self._jwks_cached_at and self._jwks_cached_at + self._jwks_ttl < time()
+        ):
             # try fetching from the cache.
             jwk = next(
                 (jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid),
@@ -1177,6 +1183,7 @@ def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
         response = self._request("GET", ".well-known/jwks.json", xform=parse_jwks)
         if response:
             self._jwks = response
+            self._jwks_cached_at = time()
 
             # find the signing key
             jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None)

From 2a4bad56b6d0c32fe4f0d681a2c4a81ebceb4bb9 Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Fri, 14 Mar 2025 11:22:29 +0100
Subject: [PATCH 13/15] jwks cache ttl tests

---
 supabase_auth/_async/gotrue_client.py | 16 +++++++-------
 supabase_auth/_sync/gotrue_client.py  | 16 +++++++-------
 tests/_async/test_gotrue.py           | 32 +++++++++++++++++++++++++++
 tests/_sync/test_gotrue.py            | 32 +++++++++++++++++++++++++++
 4 files changed, 80 insertions(+), 16 deletions(-)

diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py
index aca2d468..14baf2a7 100644
--- a/supabase_auth/_async/gotrue_client.py
+++ b/supabase_auth/_async/gotrue_client.py
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
+import time
 from contextlib import suppress
 from functools import partial
 from json import loads
-from time import time
 from typing import Callable, Dict, List, Optional, Tuple
 from urllib.parse import parse_qs, urlencode, urlparse
 from uuid import uuid4
@@ -621,7 +621,7 @@ async def get_session(self) -> Optional[Session]:
             current_session = self._in_memory_session
         if not current_session:
             return None
-        time_now = round(time())
+        time_now = round(time.time())
         has_expired = (
             current_session.expires_at <= time_now + EXPIRY_MARGIN
             if current_session.expires_at
@@ -682,7 +682,7 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon
         The current session that minimally contains an access token,
         refresh token and a user.
         """
-        time_now = round(time())
+        time_now = round(time.time())
         expires_at = time_now
         has_expired = True
         session: Optional[Session] = None
@@ -959,7 +959,7 @@ async def _get_session_from_url(
         token_type = self._get_param(params, "token_type")
         if not token_type:
             raise AuthImplicitGrantRedirectError("No token_type detected.")
-        time_now = round(time())
+        time_now = round(time.time())
         expires_at = time_now + int(expires_in)
         user = await self.get_user(access_token)
         session = Session(
@@ -982,7 +982,7 @@ async def _recover_and_refresh(self) -> None:
             if raw_session:
                 await self._remove_session()
             return
-        time_now = round(time())
+        time_now = round(time.time())
         expires_at = current_session.expires_at
         if expires_at and expires_at < time_now + EXPIRY_MARGIN:
             refresh_token = current_session.refresh_token
@@ -1034,7 +1034,7 @@ async def _save_session(self, session: Session) -> None:
             self._in_memory_session = session
         expire_at = session.expires_at
         if expire_at:
-            time_now = round(time())
+            time_now = round(time.time())
             expire_in = expire_at - time_now
             refresh_duration_before_expires = (
                 EXPIRY_MARGIN if expire_in > EXPIRY_MARGIN else 0.5
@@ -1175,7 +1175,7 @@ async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
             return jwk
 
         if self._jwks and (
-            self._jwks_cached_at and self._jwks_cached_at + self._jwks_ttl < time()
+            self._jwks_cached_at and time.time() - self._jwks_cached_at < self._jwks_ttl
         ):
             # try fetching from the cache.
             jwk = next(
@@ -1189,7 +1189,7 @@ async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
         response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks)
         if response:
             self._jwks = response
-            self._jwks_cached_at = time()
+            self._jwks_cached_at = time.time()
 
             # find the signing key
             jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None)
diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py
index 0817198c..a5ce08b1 100644
--- a/supabase_auth/_sync/gotrue_client.py
+++ b/supabase_auth/_sync/gotrue_client.py
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
+import time
 from contextlib import suppress
 from functools import partial
 from json import loads
-from time import time
 from typing import Callable, Dict, List, Optional, Tuple
 from urllib.parse import parse_qs, urlencode, urlparse
 from uuid import uuid4
@@ -619,7 +619,7 @@ def get_session(self) -> Optional[Session]:
             current_session = self._in_memory_session
         if not current_session:
             return None
-        time_now = round(time())
+        time_now = round(time.time())
         has_expired = (
             current_session.expires_at <= time_now + EXPIRY_MARGIN
             if current_session.expires_at
@@ -680,7 +680,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
         The current session that minimally contains an access token,
         refresh token and a user.
         """
-        time_now = round(time())
+        time_now = round(time.time())
         expires_at = time_now
         has_expired = True
         session: Optional[Session] = None
@@ -955,7 +955,7 @@ def _get_session_from_url(
         token_type = self._get_param(params, "token_type")
         if not token_type:
             raise AuthImplicitGrantRedirectError("No token_type detected.")
-        time_now = round(time())
+        time_now = round(time.time())
         expires_at = time_now + int(expires_in)
         user = self.get_user(access_token)
         session = Session(
@@ -978,7 +978,7 @@ def _recover_and_refresh(self) -> None:
             if raw_session:
                 self._remove_session()
             return
-        time_now = round(time())
+        time_now = round(time.time())
         expires_at = current_session.expires_at
         if expires_at and expires_at < time_now + EXPIRY_MARGIN:
             refresh_token = current_session.refresh_token
@@ -1030,7 +1030,7 @@ def _save_session(self, session: Session) -> None:
             self._in_memory_session = session
         expire_at = session.expires_at
         if expire_at:
-            time_now = round(time())
+            time_now = round(time.time())
             expire_in = expire_at - time_now
             refresh_duration_before_expires = (
                 EXPIRY_MARGIN if expire_in > EXPIRY_MARGIN else 0.5
@@ -1169,7 +1169,7 @@ def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
             return jwk
 
         if self._jwks and (
-            self._jwks_cached_at and self._jwks_cached_at + self._jwks_ttl < time()
+            self._jwks_cached_at and time.time() - self._jwks_cached_at < self._jwks_ttl
         ):
             # try fetching from the cache.
             jwk = next(
@@ -1183,7 +1183,7 @@ def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
         response = self._request("GET", ".well-known/jwks.json", xform=parse_jwks)
         if response:
             self._jwks = response
-            self._jwks_cached_at = time()
+            self._jwks_cached_at = time.time()
 
             # find the signing key
             jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None)
diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py
index 4c39c29e..53f7a487 100644
--- a/tests/_async/test_gotrue.py
+++ b/tests/_async/test_gotrue.py
@@ -1,3 +1,4 @@
+import time
 import unittest
 
 from .clients import auth_client, auth_client_with_asymmetric_session
@@ -39,3 +40,34 @@ async def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker):
 
     assert len(client._jwks["keys"]) == 1
     assert client._jwks["keys"][0]["kid"] == expected_keyid
+
+
+async def test_jwks_ttl_cache_behavior(mocker):
+    client = auth_client_with_asymmetric_session()
+
+    spy = mocker.spy(client, "_request")
+
+    # First call should fetch JWKS from endpoint
+    user = (await client.sign_up(mock_user_credentials())).user
+    assert user is not None
+
+    await client.get_claims()
+    spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY)
+    first_call_count = spy.call_count
+
+    # Second call within TTL should use cache
+    await client.get_claims()
+    assert spy.call_count == first_call_count  # No additional JWKS request
+
+    # Mock time to be after TTL expiry
+    original_time = time.time
+    try:
+        mock_time = mocker.patch("time.time")
+        mock_time.return_value = original_time() + 601  # TTL is 600 seconds
+
+        # Call after TTL expiry should fetch fresh JWKS
+        await client.get_claims()
+        assert spy.call_count == first_call_count + 1  # One more JWKS request
+    finally:
+        # Restore original time function
+        mocker.patch("time.time", original_time)
diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py
index 0ccfd032..0fb77a6e 100644
--- a/tests/_sync/test_gotrue.py
+++ b/tests/_sync/test_gotrue.py
@@ -1,3 +1,4 @@
+import time
 import unittest
 
 from .clients import auth_client, auth_client_with_asymmetric_session
@@ -39,3 +40,34 @@ def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker):
 
     assert len(client._jwks["keys"]) == 1
     assert client._jwks["keys"][0]["kid"] == expected_keyid
+
+
+def test_jwks_ttl_cache_behavior(mocker):
+    client = auth_client_with_asymmetric_session()
+
+    spy = mocker.spy(client, "_request")
+
+    # First call should fetch JWKS from endpoint
+    user = (client.sign_up(mock_user_credentials())).user
+    assert user is not None
+
+    client.get_claims()
+    spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY)
+    first_call_count = spy.call_count
+
+    # Second call within TTL should use cache
+    client.get_claims()
+    assert spy.call_count == first_call_count  # No additional JWKS request
+
+    # Mock time to be after TTL expiry
+    original_time = time.time
+    try:
+        mock_time = mocker.patch("time.time")
+        mock_time.return_value = original_time() + 601  # TTL is 600 seconds
+
+        # Call after TTL expiry should fetch fresh JWKS
+        client.get_claims()
+        assert spy.call_count == first_call_count + 1  # One more JWKS request
+    finally:
+        # Restore original time function
+        mocker.patch("time.time", original_time)

From 6a90f66a526a53029c566743fe3663bb4ffcd6e8 Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Tue, 18 Mar 2025 15:01:28 -0300
Subject: [PATCH 14/15] fix base64 token validation and add tests for
 decode_jwt

---
 supabase_auth/_async/gotrue_client.py |  12 +--
 supabase_auth/_sync/gotrue_client.py  |  12 +--
 supabase_auth/helpers.py              |   6 +-
 tests/_async/test_gotrue.py           | 120 +++++++++++++++++++++++++-
 tests/_sync/test_gotrue.py            | 120 +++++++++++++++++++++++++-
 5 files changed, 245 insertions(+), 25 deletions(-)

diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py
index 14baf2a7..6c450839 100644
--- a/supabase_auth/_async/gotrue_client.py
+++ b/supabase_auth/_async/gotrue_client.py
@@ -58,7 +58,6 @@
     AuthResponse,
     ClaimsResponse,
     CodeExchangeParams,
-    DecodedJWTDict,
     IdentitiesResponse,
     JWKSet,
     MFAChallengeAndVerifyParams,
@@ -687,7 +686,7 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon
         has_expired = True
         session: Optional[Session] = None
         if access_token and access_token.split(".")[1]:
-            payload = self._decode_jwt(access_token)
+            payload = decode_jwt(access_token)["payload"]
             exp = payload.get("exp")
             if exp:
                 expires_at = int(exp)
@@ -899,7 +898,7 @@ async def _get_authenticator_assurance_level(
                 next_level=None,
                 current_authentication_methods=[],
             )
-        payload = self._decode_jwt(session.access_token)
+        payload = decode_jwt(session.access_token)["payload"]
         current_level: Optional[AuthenticatorAssuranceLevels] = None
         if payload.get("aal"):
             current_level = payload.get("aal")
@@ -1137,13 +1136,6 @@ async def _get_url_for_provider(
         query = urlencode(params)
         return f"{url}?{query}", params
 
-    def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
-        """
-        Decodes a JWT (without performing any validation).
-        """
-        decoded = decode_jwt(jwt)
-        return decoded["payload"]
-
     async def exchange_code_for_session(self, params: CodeExchangeParams):
         code_verifier = params.get("code_verifier") or await self._storage.get_item(
             f"{self._storage_key}-code-verifier"
diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py
index a5ce08b1..fabae947 100644
--- a/supabase_auth/_sync/gotrue_client.py
+++ b/supabase_auth/_sync/gotrue_client.py
@@ -58,7 +58,6 @@
     AuthResponse,
     ClaimsResponse,
     CodeExchangeParams,
-    DecodedJWTDict,
     IdentitiesResponse,
     JWKSet,
     MFAChallengeAndVerifyParams,
@@ -685,7 +684,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
         has_expired = True
         session: Optional[Session] = None
         if access_token and access_token.split(".")[1]:
-            payload = self._decode_jwt(access_token)
+            payload = decode_jwt(access_token)["payload"]
             exp = payload.get("exp")
             if exp:
                 expires_at = int(exp)
@@ -895,7 +894,7 @@ def _get_authenticator_assurance_level(
                 next_level=None,
                 current_authentication_methods=[],
             )
-        payload = self._decode_jwt(session.access_token)
+        payload = decode_jwt(session.access_token)["payload"]
         current_level: Optional[AuthenticatorAssuranceLevels] = None
         if payload.get("aal"):
             current_level = payload.get("aal")
@@ -1131,13 +1130,6 @@ def _get_url_for_provider(
         query = urlencode(params)
         return f"{url}?{query}", params
 
-    def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
-        """
-        Decodes a JWT (without performing any validation).
-        """
-        decoded = decode_jwt(jwt)
-        return decoded["payload"]
-
     def exchange_code_for_session(self, params: CodeExchangeParams):
         code_verifier = params.get("code_verifier") or self._storage.get_item(
             f"{self._storage_key}-code-verifier"
diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py
index 9aebab0d..9407e630 100644
--- a/supabase_auth/helpers.py
+++ b/supabase_auth/helpers.py
@@ -229,9 +229,9 @@ def decode_jwt(token: str) -> DecodedJWT:
         raise AuthInvalidJwtError("Invalid JWT structure")
 
     # regex check for base64url
-    # for part in parts:
-    #     if not re.match(BASE64URL_REGEX, part):
-    #         raise AuthInvalidJwtError("JWT not in base64url format")
+    for part in parts:
+        if not re.match(BASE64URL_REGEX, part, re.IGNORECASE):
+            raise AuthInvalidJwtError("JWT not in base64url format")
 
     return DecodedJWT(
         header=JWTHeader(**loads(str_from_base64url(parts[0]))),
diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py
index 53f7a487..2a14a562 100644
--- a/tests/_async/test_gotrue.py
+++ b/tests/_async/test_gotrue.py
@@ -1,7 +1,13 @@
 import time
 import unittest
 
-from .clients import auth_client, auth_client_with_asymmetric_session
+import pytest
+from jwt import encode
+
+from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError
+from supabase_auth.helpers import decode_jwt
+
+from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session
 from .utils import mock_user_credentials
 
 
@@ -71,3 +77,115 @@ async def test_jwks_ttl_cache_behavior(mocker):
     finally:
         # Restore original time function
         mocker.patch("time.time", original_time)
+
+
+async def test_set_session_with_valid_tokens():
+    client = auth_client()
+    credentials = mock_user_credentials()
+
+    # First sign up to get valid tokens
+    signup_response = await client.sign_up(
+        {
+            "email": credentials.get("email"),
+            "password": credentials.get("password"),
+        }
+    )
+    assert signup_response.session is not None
+
+    # Get the tokens from the signup response
+    access_token = signup_response.session.access_token
+    refresh_token = signup_response.session.refresh_token
+
+    # Clear the session
+    await client._remove_session()
+
+    # Set the session with the tokens
+    response = await client.set_session(access_token, refresh_token)
+
+    # Verify the response
+    assert response.session is not None
+    assert response.session.access_token == access_token
+    assert response.session.refresh_token == refresh_token
+    assert response.user is not None
+    assert response.user.email == credentials.get("email")
+
+
+async def test_set_session_with_expired_token():
+    client = auth_client()
+    credentials = mock_user_credentials()
+
+    # First sign up to get valid tokens
+    signup_response = await client.sign_up(
+        {
+            "email": credentials.get("email"),
+            "password": credentials.get("password"),
+        }
+    )
+    assert signup_response.session is not None
+
+    # Get the tokens from the signup response
+    access_token = signup_response.session.access_token
+    refresh_token = signup_response.session.refresh_token
+
+    # Clear the session
+    await client._remove_session()
+
+    # Create an expired token by modifying the JWT
+    expired_token = access_token.split(".")
+    payload = decode_jwt(access_token)["payload"]
+    payload["exp"] = int(time.time()) - 3600  # Set expiry to 1 hour ago
+    expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[
+        1
+    ]
+    expired_access_token = ".".join(expired_token)
+
+    # Set the session with the expired token
+    response = await client.set_session(expired_access_token, refresh_token)
+
+    # Verify the response has a new access token (refreshed)
+    assert response.session is not None
+    assert response.session.access_token != expired_access_token
+    assert response.session.refresh_token != refresh_token
+    assert response.user is not None
+    assert response.user.email == credentials.get("email")
+
+
+async def test_set_session_without_refresh_token():
+    client = auth_client()
+    credentials = mock_user_credentials()
+
+    # First sign up to get valid tokens
+    signup_response = await client.sign_up(
+        {
+            "email": credentials.get("email"),
+            "password": credentials.get("password"),
+        }
+    )
+    assert signup_response.session is not None
+
+    # Get the access token from the signup response
+    access_token = signup_response.session.access_token
+
+    # Clear the session
+    await client._remove_session()
+
+    # Create an expired token
+    expired_token = access_token.split(".")
+    payload = decode_jwt(access_token)["payload"]
+    payload["exp"] = int(time.time()) - 3600  # Set expiry to 1 hour ago
+    expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[
+        1
+    ]
+    expired_access_token = ".".join(expired_token)
+
+    # Try to set the session with an expired token but no refresh token
+    with pytest.raises(AuthSessionMissingError):
+        await client.set_session(expired_access_token, "")
+
+
+async def test_set_session_with_invalid_token():
+    client = auth_client()
+
+    # Try to set the session with invalid tokens
+    with pytest.raises(AuthInvalidJwtError):
+        await client.set_session("invalid.token.here", "invalid_refresh_token")
diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py
index 0fb77a6e..b68fc9e2 100644
--- a/tests/_sync/test_gotrue.py
+++ b/tests/_sync/test_gotrue.py
@@ -1,7 +1,13 @@
 import time
 import unittest
 
-from .clients import auth_client, auth_client_with_asymmetric_session
+import pytest
+from jwt import encode
+
+from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError
+from supabase_auth.helpers import decode_jwt
+
+from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session
 from .utils import mock_user_credentials
 
 
@@ -71,3 +77,115 @@ def test_jwks_ttl_cache_behavior(mocker):
     finally:
         # Restore original time function
         mocker.patch("time.time", original_time)
+
+
+def test_set_session_with_valid_tokens():
+    client = auth_client()
+    credentials = mock_user_credentials()
+
+    # First sign up to get valid tokens
+    signup_response = client.sign_up(
+        {
+            "email": credentials.get("email"),
+            "password": credentials.get("password"),
+        }
+    )
+    assert signup_response.session is not None
+
+    # Get the tokens from the signup response
+    access_token = signup_response.session.access_token
+    refresh_token = signup_response.session.refresh_token
+
+    # Clear the session
+    client._remove_session()
+
+    # Set the session with the tokens
+    response = client.set_session(access_token, refresh_token)
+
+    # Verify the response
+    assert response.session is not None
+    assert response.session.access_token == access_token
+    assert response.session.refresh_token == refresh_token
+    assert response.user is not None
+    assert response.user.email == credentials.get("email")
+
+
+def test_set_session_with_expired_token():
+    client = auth_client()
+    credentials = mock_user_credentials()
+
+    # First sign up to get valid tokens
+    signup_response = client.sign_up(
+        {
+            "email": credentials.get("email"),
+            "password": credentials.get("password"),
+        }
+    )
+    assert signup_response.session is not None
+
+    # Get the tokens from the signup response
+    access_token = signup_response.session.access_token
+    refresh_token = signup_response.session.refresh_token
+
+    # Clear the session
+    client._remove_session()
+
+    # Create an expired token by modifying the JWT
+    expired_token = access_token.split(".")
+    payload = decode_jwt(access_token)["payload"]
+    payload["exp"] = int(time.time()) - 3600  # Set expiry to 1 hour ago
+    expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[
+        1
+    ]
+    expired_access_token = ".".join(expired_token)
+
+    # Set the session with the expired token
+    response = client.set_session(expired_access_token, refresh_token)
+
+    # Verify the response has a new access token (refreshed)
+    assert response.session is not None
+    assert response.session.access_token != expired_access_token
+    assert response.session.refresh_token != refresh_token
+    assert response.user is not None
+    assert response.user.email == credentials.get("email")
+
+
+def test_set_session_without_refresh_token():
+    client = auth_client()
+    credentials = mock_user_credentials()
+
+    # First sign up to get valid tokens
+    signup_response = client.sign_up(
+        {
+            "email": credentials.get("email"),
+            "password": credentials.get("password"),
+        }
+    )
+    assert signup_response.session is not None
+
+    # Get the access token from the signup response
+    access_token = signup_response.session.access_token
+
+    # Clear the session
+    client._remove_session()
+
+    # Create an expired token
+    expired_token = access_token.split(".")
+    payload = decode_jwt(access_token)["payload"]
+    payload["exp"] = int(time.time()) - 3600  # Set expiry to 1 hour ago
+    expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[
+        1
+    ]
+    expired_access_token = ".".join(expired_token)
+
+    # Try to set the session with an expired token but no refresh token
+    with pytest.raises(AuthSessionMissingError):
+        client.set_session(expired_access_token, "")
+
+
+def test_set_session_with_invalid_token():
+    client = auth_client()
+
+    # Try to set the session with invalid tokens
+    with pytest.raises(AuthInvalidJwtError):
+        client.set_session("invalid.token.here", "invalid_refresh_token")

From 66ffe5d432aca0381ac3d3825e9cd1da4830d66d Mon Sep 17 00:00:00 2001
From: Guilherme Souza <guilherme@supabase.io>
Date: Tue, 18 Mar 2025 15:01:45 -0300
Subject: [PATCH 15/15] clean up _refresh_token_timer

---
 supabase_auth/_async/gotrue_client.py | 13 +++++++++++++
 supabase_auth/_sync/gotrue_client.py  | 13 +++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py
index 6c450839..02befb90 100644
--- a/supabase_auth/_async/gotrue_client.py
+++ b/supabase_auth/_async/gotrue_client.py
@@ -1237,3 +1237,16 @@ async def get_claims(
 
         # If verification succeeds, decode and return claims
         return ClaimsResponse(claims=payload, headers=header, signature=signature)
+
+    def __del__(self) -> None:
+        """Clean up resources when the client is destroyed."""
+        if self._refresh_token_timer:
+            try:
+                # Try to cancel the timer
+                self._refresh_token_timer.cancel()
+            except:
+                # Ignore errors if event loop is closed or selector is not registered
+                pass
+            finally:
+                # Always set to None to prevent further attempts
+                self._refresh_token_timer = None
diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py
index fabae947..169ebd0e 100644
--- a/supabase_auth/_sync/gotrue_client.py
+++ b/supabase_auth/_sync/gotrue_client.py
@@ -1231,3 +1231,16 @@ def get_claims(
 
         # If verification succeeds, decode and return claims
         return ClaimsResponse(claims=payload, headers=header, signature=signature)
+
+    def __del__(self) -> None:
+        """Clean up resources when the client is destroyed."""
+        if self._refresh_token_timer:
+            try:
+                # Try to cancel the timer
+                self._refresh_token_timer.cancel()
+            except:
+                # Ignore errors if event loop is closed or selector is not registered
+                pass
+            finally:
+                # Always set to None to prevent further attempts
+                self._refresh_token_timer = None