Skip to content

Commit

Permalink
Merge 23edee2 into ae3618c
Browse files Browse the repository at this point in the history
  • Loading branch information
olevski committed Jan 25, 2024
2 parents ae3618c + 23edee2 commit 16b36fb
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 38 deletions.
3 changes: 3 additions & 0 deletions renku/ui/service/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from renku.ui.service.logger import service_log
from renku.ui.service.serializers.headers import JWT_TOKEN_SECRET
from renku.ui.service.utils import jwk_client
from renku.ui.service.utils.json_encoder import SvcJSONProvider
from renku.ui.service.views import error_response
from renku.ui.service.views.apispec import apispec_blueprint
Expand Down Expand Up @@ -76,6 +77,8 @@ def create_app(custom_exceptions=True):

app.config["cache"] = cache

app.config["KEYCLOAK_JWK_CLIENT"] = jwk_client()

if not is_test_session_running():
GunicornPrometheusMetrics(app)

Expand Down
75 changes: 39 additions & 36 deletions renku/ui/service/serializers/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@
import base64
import binascii
import os
from typing import cast

import jwt
from flask import current_app
from marshmallow import Schema, ValidationError, fields, post_load, pre_load
from werkzeug.utils import secure_filename

from renku.ui.service.logger import service_log

JWT_TOKEN_SECRET = os.getenv("RENKU_JWT_TOKEN_SECRET", "bW9menZ3cnh6cWpkcHVuZ3F5aWJycmJn")


Expand Down Expand Up @@ -79,7 +83,7 @@ class RenkuHeaders:

@staticmethod
def decode_token(token):
"""Extract authorization token."""
"""Extract the Gitlab access token form a bearer authorization header value."""
components = token.split(" ")

rfc_compliant = token.lower().startswith("bearer")
Expand All @@ -92,44 +96,43 @@ def decode_token(token):

@staticmethod
def decode_user(data):
    """Extract renku user from the Keycloak ID token which is a JWT.

    The token is first validated as an RS256-signed token, using the signing key that the
    JWK client stored in the Flask app config under ``KEYCLOAK_JWK_CLIENT`` resolves for it.
    If that fails with any ``jwt.PyJWTError``, validation is retried with HS256 and
    ``JWT_TOKEN_SECRET``. Errors raised by the HS256 fallback are not caught here.

    Args:
        data: The encoded JWT.

    Returns:
        The result of loading the decoded claims with ``UserIdentityToken``.
    """
    # NOTE: Never log the token itself or key material - the token is a bearer credential and anyone
    # with access to the logs could replay it.
    try:
        jwk = cast(jwt.PyJWKClient, current_app.config["KEYCLOAK_JWK_CLIENT"])
        key = jwk.get_signing_key_from_jwt(data)
        decoded = jwt.decode(data, key=key.key, algorithms=["RS256"], audience="renku")
    except jwt.PyJWTError as e:
        # NOTE: older tokens used to be signed with HS256 so use this as a backup if the validation with RS256
        # above fails. We used to need HS256 because a step that is now removed was generating an ID token and
        # signing it from data passed in individual header fields.
        service_log.info(f"RS256 validation of the ID token failed ({e}), trying with HS256")
        decoded = jwt.decode(data, JWT_TOKEN_SECRET, algorithms=["HS256"], audience="renku")
    return UserIdentityToken().load(decoded)

@staticmethod
def reset_old_headers(data):
    """Convert the legacy per-field user headers into a single ``renku-user`` JWT.

    ``renku-user-id`` is always dropped. When both ``renku-user-fullname`` and
    ``renku-user-email`` are present they are removed, base64-decoded and replaced by a
    ``renku-user`` entry holding an HS256 token signed with ``JWT_TOKEN_SECRET``.
    ``data`` is modified in place and also returned.
    """
    # TODO: This should be removed once support for them is phased out.
    data.pop("renku-user-id", None)

    has_legacy_identity = "renku-user-fullname" in data and "renku-user-email" in data
    if not has_legacy_identity:
        return data

    # NOTE: Pop and decode the name before the email, and keep the claim order stable, so the
    # encoded token stays the same as before.
    full_name = decode_b64(data.pop("renku-user-fullname"))
    email = decode_b64(data.pop("renku-user-email"))
    claims = {"aud": ["renku"], "name": full_name, "email": email, "sub": email}
    data["renku-user"] = jwt.encode(claims, JWT_TOKEN_SECRET, algorithm="HS256")

    return data


class IdentityHeaders(Schema):
"""User identity schema."""

@pre_load
def set_fields(self, data, **kwargs):
    """Keep only identity-related headers, lowercase their names and convert legacy ones."""
    # NOTE: We don't process headers which are not meant for determining identity.
    # TODO: Remove old headers support once support for them is phased out.
    accepted = ["renku-user-id", "renku-user-fullname", "renku-user-email"]
    accepted.extend(field.data_key for field in self.fields.values())

    filtered = {}
    for name, value in data.items():
        lowered = name.lower()
        if lowered in accepted:
            filtered[lowered] = value

    return RenkuHeaders.reset_old_headers(filtered)
def lowercase_required_headers(self, data, **kwargs):
    """Copy the identity headers from ``data`` into a new dict keyed by lowercase names.

    Only the authorization and renku-user headers are copied, each under the first of its
    known spellings that is present in ``data``. A header that is absent is simply left out
    of the result, so that required/optional validation is left to the schema fields.

    Args:
        data: Mapping of incoming header names to values.

    Returns:
        A new dict with at most the keys ``authorization`` and ``renku-user``.
    """
    # NOTE: App flask headers are immutable and raise an error when modified so we copy them here
    spellings = {
        "authorization": ("Authorization", "authorization"),
        "renku-user": ("Renku-User", "Renku-user", "renku-user"),
    }
    output = {}
    for target, candidates in spellings.items():
        for candidate in candidates:
            # NOTE: Test membership explicitly; a bare string literal is always truthy and would
            # make the lookup below raise when the header is missing.
            if candidate in data:
                output[target] = data[candidate]
                break

    return output

@post_load
def set_user(self, data, **kwargs):
Expand All @@ -151,12 +154,12 @@ def set_user(self, data, **kwargs):
class RequiredIdentityHeaders(IdentityHeaders):
    """Identity schema for required headers."""

    # NOTE: Each field is declared exactly once; a repeated assignment would just overwrite the first.
    user_token = fields.String(required=True, data_key="renku-user")  # Keycloak ID token
    auth_token = fields.String(required=True, data_key="authorization")  # Gitlab access token


class OptionalIdentityHeaders(IdentityHeaders):
    """Identity schema for optional headers."""

    # NOTE: Each field is declared exactly once; a repeated assignment would just overwrite the first.
    user_token = fields.String(data_key="renku-user")  # Keycloak ID token
    auth_token = fields.String(data_key="authorization")  # Gitlab access token
53 changes: 51 additions & 2 deletions renku/ui/service/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Renku service utility functions."""
from typing import Optional, overload
import os
import urllib
from time import sleep
from typing import Any, Dict, Optional, overload

from renku.ui.service.config import CACHE_PROJECTS_PATH, CACHE_UPLOADS_PATH
import requests
from jwt import PyJWKClient

from renku.core.util.requests import get
from renku.ui.service.config import CACHE_PROJECTS_PATH, CACHE_UPLOADS_PATH, OIDC_URL
from renku.ui.service.errors import ProgramInternalError
from renku.ui.service.logger import service_log


def make_project_path(user, project):
Expand Down Expand Up @@ -86,3 +95,43 @@ def normalize_git_url(git_url: Optional[str]) -> Optional[str]:
git_url = git_url[: -len(".git")]

return git_url


def oidc_discovery() -> Dict[str, Any]:
    """Query the OIDC discovery endpoint from Keycloak with retries and return the parsed JSON result.

    The URL is built from the ``RENKU_DOMAIN`` environment variable and ``OIDC_URL``. A failed
    request (connection problem, timeout or error status code) is retried every 2 seconds, for at
    most 30 attempts in total.

    Returns:
        The OIDC discovery data decoded from the JSON body of the response.

    Raises:
        ProgramInternalError: If ``RENKU_DOMAIN`` is not set or is empty.
        requests.exceptions.RequestException: If the last attempt still fails.
        urllib.error.HTTPError: If the last attempt still fails with this error.
    """
    # NOTE: A plain ``import urllib`` does not guarantee that the ``urllib.error`` submodule is loaded.
    import urllib.error

    retries = 0
    max_retries = 30
    sleep_seconds = 2
    renku_domain = os.environ.get("RENKU_DOMAIN")
    if not renku_domain:
        raise ProgramInternalError(
            error_message="Cannot perform OIDC discovery without the renku domain expected "
            "to be found in the RENKU_DOMAIN environment variable."
        )
    full_oidc_url = f"http://{renku_domain}{OIDC_URL}"
    while True:
        retries += 1
        try:
            res: requests.Response = get(full_oidc_url)
            # NOTE: Turn error status codes into exceptions so that they are retried as well instead of
            # failing later when the body is parsed.
            res.raise_for_status()
        except (requests.exceptions.RequestException, urllib.error.HTTPError):
            # NOTE: RequestException also covers connection errors and timeouts, which are the expected
            # failures while Keycloak is still starting up.
            # NOTE(review): if ``get`` wraps failures in a project-specific exception type, that type has to
            # be added to the tuple above - confirm against its implementation.
            if not retries < max_retries:
                service_log.error("Failed to get OIDC discovery data after all retries - the server cannot start.")
                raise
            service_log.info(
                f"Failed to get OIDC discovery data from {full_oidc_url}, "
                f"sleeping for {sleep_seconds} seconds and retrying"
            )
            sleep(sleep_seconds)
        else:
            service_log.info(f"Successfully fetched OIDC discovery data from {full_oidc_url}")
            return res.json()


def jwk_client() -> PyJWKClient:
    """Create the Keycloak JWK client that provides the keys for JWT signature validation.

    The key set location is read from the ``jwks_uri`` entry of the OIDC discovery data.

    Raises:
        ProgramInternalError: If the discovery data has no (or an empty) ``jwks_uri``.
    """
    jwks_uri = oidc_discovery().get("jwks_uri")
    if jwks_uri:
        return PyJWKClient(jwks_uri)
    raise ProgramInternalError(error_message="Could not find jwks_uri in the OIDC discovery data")

0 comments on commit 16b36fb

Please sign in to comment.