Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions .github/workflows/github-actions-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
python: [3.7, 3.8, 3.9]
include:
- os: ubuntu-20.04
python: 3.6
python: [3.8, 3.9]

steps:
- uses: actions/checkout@v4
Expand Down
53 changes: 43 additions & 10 deletions iam_python_sdk/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,30 @@
"""Flask module."""

from functools import wraps
from typing import Optional, Union, List
from flask import current_app, Flask, request
from typing import List, Optional, Union
from urllib.parse import urlparse

from flask import Flask, current_app, request
from flask.helpers import make_response
from flask.wrappers import Response
from urllib.parse import urlparse
from werkzeug.exceptions import HTTPException

from .config import Config
from .cache import Cache
from .client import DefaultClient, NewDefaultClient
from .errors import Error as IAMError, ClientTokenGrantError, GetClientInformationError, StartLocalValidationError, \
TokenRevokedError, UserRevokedError, ValidateAndParseClaimsError, ValidatePermissionError
from .http_errors import InsufficientPermissions, InternalServerError, InvalidRefererHeader, UnauthorizedAccess, \
SubdomainMismatch
from .config import Config
from .errors import ClientTokenGrantError
from .errors import Error as IAMError
from .errors import (GetClientInformationError, StartLocalValidationError,
TokenRevokedError, UserRevokedError,
ValidateAndParseClaimsError, ValidatePermissionError)
from .http_errors import (InsufficientPermissions, InternalServerError,
InvalidRefererHeader, SubdomainMismatch,
UnauthorizedAccess)
from .models import JWTClaims, Permission


# ---------- Exceptions ---------- #


class HTTPError(HTTPException):
def __init__(self, http_code: int, error_code: int, message: str, description: Optional[str] = None) -> None:
super().__init__(description)
Expand Down Expand Up @@ -93,6 +99,7 @@ class IAM:
"""

def __init__(self, app: Union[Flask, None] = None) -> None:
self.client_info_cache = Cache(ttl=60)
self.app = app
if app is not None:
self.init_app(app)
Expand Down Expand Up @@ -214,7 +221,33 @@ def validate_referer_header(self, jwt_claims: JWTClaims) -> bool:
bool: Is referer header valid or not
"""
try:
client_info = self.client.GetClientInformation(jwt_claims.Namespace, jwt_claims.ClientId)
# Cache implementation to handle race conditions during IAM URL changes
# When IAM URL is updated, there might be existing valid JWTs that were
# issued with the old URL. This cache ensures those tokens can still be
# validated during the transition period without making redundant requests
# to IAM for the same client information.

# Create cache key using namespace and client ID from JWT claims
# This combination uniquely identifies the client across IAM URL changes
cache_key = f"{jwt_claims.Namespace}:{jwt_claims.ClientId}"

# Try to get client info from cache first to avoid unnecessary IAM requests
# during the URL transition period. This is particularly important when
# handling multiple requests with JWTs issued under the old URL.
client_info = self.client_info_cache.get(cache_key)

if client_info is None:
# Cache miss - need to fetch from IAM
# This will use the current IAM URL configuration, but the response
# will be cached to handle subsequent requests that might still be
# using JWTs issued with the old URL
client_info = self.client.GetClientInformation(jwt_claims.Namespace, jwt_claims.ClientId)
if client_info:
# Store successful response in cache
# This ensures we can handle subsequent requests with old JWTs
# without making additional IAM requests during the URL transition
self.client_info_cache[cache_key] = client_info

except GetClientInformationError:
return False

Expand Down
238 changes: 238 additions & 0 deletions tests/test_referer_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
# Copyright 2025 AccelByte Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import Mock, patch

from flask import Flask

from iam_python_sdk.client import DefaultClient, JWTClaims
from iam_python_sdk.flask import IAM, HTTPError
from iam_python_sdk.models import ClientInformation


class TestInvalidRefererHeader(unittest.TestCase):
@patch("iam_python_sdk.flask.NewDefaultClient")
def setUp(self, mock_new_client):
"""
Common test setup with mocked client initialization
"""
# Given a Flask application with IAM configuration
self.app = Flask(__name__)
self.app.config.update(
{
"IAM_BASE_URL": "http://iam-test.local",
"IAM_CLIENT_ID": "test-client",
"IAM_CLIENT_SECRET": "test-secret",
"IAM_TOKEN_LOCATIONS": ["cookies"],
"IAM_TOKEN_COOKIE_NAME": "access_token",
"IAM_CSRF_PROTECTION": True,
"IAM_STRICT_REFERER": True,
}
)

# Mock the client initialization
self.mock_client = Mock(spec=DefaultClient)
mock_new_client.return_value = self.mock_client

# Initialize IAM with mocked client
self.iam = IAM(self.app)

# And prepared JWT claims
self.jwt_claims = JWTClaims()
self.jwt_claims.Namespace = "test-namespace"
self.jwt_claims.ClientId = "test-client-id"

# And prepared client information
self.client_info = ClientInformation()
self.client_info.Redirecturi = "https://allowed-domain.com/callback"

def test_missing_referer_header_should_raise_error(self):
"""
Given a request without a referer header
When validating the token
Then it should raise an HTTPError with appropriate error codes
"""
# Given
ctx = self.app.test_request_context(
"/",
headers={"Cookie": "access_token=test-token"},
environ_base={"HTTP_REFERER": ""},
)
self.mock_client.ValidateAndParseClaims.return_value = self.jwt_claims
self.mock_client.GetClientInformation.return_value = self.client_info

# When/Then
with ctx:
with self.assertRaises(HTTPError) as context:
self.iam.validate_token_in_request(validate_referer=True)

error = context.exception
self.assertEqual(error.code, 401)
self.assertEqual(error.error_code, 20023)
self.assertIn("Invalid referrer header", error.description)

def test_wrong_domain_in_referer_header_should_raise_error(self):
"""
Given a request with a referer header from an unauthorized domain
When validating the token
Then it should raise an HTTPError with appropriate error codes
"""
# Given
ctx = self.app.test_request_context(
"/",
headers={"Cookie": "access_token=test-token"},
environ_base={"HTTP_REFERER": "https://malicious-domain.com"},
)
self.mock_client.ValidateAndParseClaims.return_value = self.jwt_claims
self.mock_client.GetClientInformation.return_value = self.client_info

# When/Then
with ctx:
with self.assertRaises(HTTPError) as context:
self.iam.validate_token_in_request(validate_referer=True)

error = context.exception
self.assertEqual(error.code, 401)
self.assertEqual(error.error_code, 20023)
self.assertIn("Invalid referrer header", error.description)

def test_malformed_referer_url_should_raise_error(self):
"""
Given a request with a malformed referer URL
When validating the token
Then it should raise an HTTPError with appropriate error codes
"""
# Given
ctx = self.app.test_request_context(
"/",
headers={"Cookie": "access_token=test-token"},
environ_base={"HTTP_REFERER": "not-a-valid-url"},
)
self.mock_client.ValidateAndParseClaims.return_value = self.jwt_claims
self.mock_client.GetClientInformation.return_value = self.client_info

# When/Then
with ctx:
with self.assertRaises(HTTPError) as context:
self.iam.validate_token_in_request(validate_referer=True)

error = context.exception
self.assertEqual(error.code, 401)
self.assertEqual(error.error_code, 20023)
self.assertIn("Invalid referrer header", error.description)

def test_empty_client_redirect_uri_should_allow_any_referer(self):
"""
Given a client with no configured redirect URIs
When validating the token with any referer
Then it should allow the request
"""
# Given
ctx = self.app.test_request_context(
"/",
headers={"Cookie": "access_token=test-token"},
environ_base={"HTTP_REFERER": "https://any-domain.com"},
)
empty_client_info = ClientInformation()
empty_client_info.Redirecturi = ""
self.mock_client.ValidateAndParseClaims.return_value = self.jwt_claims
self.mock_client.GetClientInformation.return_value = empty_client_info

# When
with ctx:
result = self.iam.validate_token_in_request(validate_referer=True)

# Then
self.assertEqual(result, self.jwt_claims)

def test_valid_referer_header_should_succeed(self):
"""
Given a request with a valid referer header
When validating the token
Then it should complete successfully
"""
# Given
ctx = self.app.test_request_context(
"/",
headers={"Cookie": "access_token=test-token"},
environ_base={"HTTP_REFERER": "https://allowed-domain.com/callback"},
)
self.mock_client.ValidateAndParseClaims.return_value = self.jwt_claims
self.mock_client.GetClientInformation.return_value = self.client_info

# When
with ctx:
result = self.iam.validate_token_in_request(validate_referer=True)

# Then
self.assertEqual(result, self.jwt_claims)

def test_changing_referer_mid_process_should_succeed(self):
"""
Given an initial request with a valid referer and cached client info
When the client's redirect URIs are updated and a request comes from a new valid domain
Then the request should succeed using the cached and updated client info
"""
# Given - Initial setup with first domain
initial_client_info = ClientInformation()
initial_client_info.Redirecturi = "https://allowed-domain.com/callback"

initial_ctx = self.app.test_request_context(
"/",
headers={"Cookie": "access_token=test-token"},
environ_base={"HTTP_REFERER": "https://allowed-domain.com/callback"},
)
self.mock_client.ValidateAndParseClaims.return_value = self.jwt_claims
self.mock_client.GetClientInformation.return_value = initial_client_info

# First request should succeed and cache the client info
with initial_ctx:
result = self.iam.validate_token_in_request(validate_referer=True)
self.assertEqual(result, self.jwt_claims)

# Verify GetClientInformation was called once
self.mock_client.GetClientInformation.assert_called_once_with(
self.jwt_claims.Namespace, self.jwt_claims.ClientId
)

# Clear the cache to simulate expiration
self.iam.client_info_cache.clear()

# When - Client info is updated with new redirect URI
updated_client_info = ClientInformation()
updated_client_info.Redirecturi = "https://new-domain.com/callback"
self.mock_client.GetClientInformation.reset_mock()
self.mock_client.GetClientInformation.return_value = updated_client_info

# Create new request from the new domain
new_domain_ctx = self.app.test_request_context(
"/",
headers={"Cookie": "access_token=test-token"},
environ_base={"HTTP_REFERER": "https://new-domain.com/callback"},
)

# Then - The request should succeed with new client info
with new_domain_ctx:
result = self.iam.validate_token_in_request(validate_referer=True)
self.assertEqual(result, self.jwt_claims)

# Verify GetClientInformation was called again (after cache clear)
self.mock_client.GetClientInformation.assert_called_once_with(
self.jwt_claims.Namespace, self.jwt_claims.ClientId
)


if __name__ == "__main__":
unittest.main()