Skip to content

Commit

Permalink
feat(core): use oauth 2.0 device auth grant (#2722)
Browse files Browse the repository at this point in the history
  • Loading branch information
m-alisafaee committed May 24, 2022
1 parent 713b4a4 commit eae254e
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 106 deletions.
107 changes: 71 additions & 36 deletions renku/command/login.py
Expand Up @@ -18,9 +18,8 @@
"""Logging in to a Renku deployment."""

import os
import sys
import time
import urllib
import uuid
import webbrowser
from typing import TYPE_CHECKING

Expand All @@ -37,6 +36,8 @@


CONFIG_SECTION = "http"
KEYCLOAK_REALM = "Renku"
CLIENT_ID = "renku-cli"


def login_command():
Expand Down Expand Up @@ -67,48 +68,82 @@ def _login(endpoint, git_login, yes, client_dispatcher: IClientDispatcher):
else:
raise errors.ParameterError("Cannot find a unique remote URL for project.")

cli_nonce = str(uuid.uuid4())
auth_server_url = _get_url(
parsed_endpoint, path=f"auth/realms/{KEYCLOAK_REALM}/protocol/openid-connect/auth/device"
)

communication.echo(f"Please log in at {parsed_endpoint.geturl()} on your browser.")
try:
response = requests.post(auth_server_url, data={"client_id": CLIENT_ID})
except errors.RequestError as e:
raise errors.RequestError(f"Cannot connect to authorization server at {auth_server_url}.") from e

login_url = _get_url(parsed_endpoint, "/api/auth/login", cli_nonce=cli_nonce)
webbrowser.open_new_tab(login_url)
requests.check_response(response=response)
data = response.json()

server_nonce = communication.prompt("Once completed, enter the security code that you receive at the end")
cli_token_url = _get_url(parsed_endpoint, "/api/auth/cli-token", cli_nonce=cli_nonce, server_nonce=server_nonce)
verification_uri = data.get("verification_uri")
user_code = data.get("user_code")
verification_uri_complete = f"{verification_uri}?user_code={user_code}"

try:
response = requests.get(cli_token_url)
except errors.RequestError as e:
raise errors.OperationError("Cannot get access token from remote host.") from e
communication.echo(
f"Please grant access to '{CLIENT_ID}' in your browser.\n"
f"If a browser window does not open automatically, go to {verification_uri_complete}"
)

if response.status_code == 200:
access_token = response.json().get("access_token")
_store_token(parsed_endpoint.netloc, access_token)
webbrowser.open_new_tab(verification_uri_complete)

if git_login:
_set_git_credential_helper(repository=client.repository, hostname=parsed_endpoint.netloc)
backup_remote_name, backup_exists, remote = create_backup_remote(
repository=client.repository, remote_name=remote_name, url=remote_url # type:ignore
)
if backup_exists:
communication.echo(f"Backup remote '{backup_remote_name}' already exists. Ignoring '--git' flag.")
elif not remote:
communication.error(f"Cannot create backup remote '{backup_remote_name}' for '{remote_url}'")
polling_interval = min(data.get("interval", 5), 5)
token_url = _get_url(parsed_endpoint, path=f"auth/realms/{KEYCLOAK_REALM}/protocol/openid-connect/token")
device_code = data.get("device_code")

while True:
time.sleep(polling_interval)

response = requests.post(
token_url,
data={
"device_code": device_code,
"client_id": CLIENT_ID,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
},
)
status_code = response.status_code
if status_code == 200:
break
elif status_code == 400:
error = response.json().get("error")

if error == "authorization_pending":
continue
elif error == "slow_down":
polling_interval += 1
elif error == "access_denied":
raise errors.AuthenticationError("Access denied")
elif error == "expired_token":
raise errors.AuthenticationError("Session expired, try again")
else:
_set_renku_url_for_remote(
repository=client.repository,
remote_name=remote_name, # type:ignore
remote_url=remote_url, # type:ignore
hostname=parsed_endpoint.netloc,
)
raise errors.AuthenticationError(f"Invalid error message from server: {response.json()}")
else:
raise errors.AuthenticationError(f"Invalid status code from server: {status_code} - {response.content}")

else:
communication.error(
f"Remote host did not return an access token: {parsed_endpoint.geturl()}, "
f"status code: {response.status_code}"
access_token = response.json().get("access_token")
_store_token(parsed_endpoint.netloc, access_token)

if git_login:
_set_git_credential_helper(repository=client.repository, hostname=parsed_endpoint.netloc)
backup_remote_name, backup_exists, remote = create_backup_remote(
repository=client.repository, remote_name=remote_name, url=remote_url # type:ignore
)
sys.exit(1)
if backup_exists:
communication.echo(f"Backup remote '{backup_remote_name}' already exists. Ignoring '--git' flag.")
elif not remote:
communication.error(f"Cannot create backup remote '{backup_remote_name}' for '{remote_url}'")
else:
_set_renku_url_for_remote(
repository=client.repository,
remote_name=remote_name, # type:ignore
remote_url=remote_url, # type:ignore
hostname=parsed_endpoint.netloc,
)


def _parse_endpoint(endpoint):
Expand All @@ -119,7 +154,7 @@ def _parse_endpoint(endpoint):
return parsed_endpoint


def _get_url(parsed_endpoint, path, **query_args):
def _get_url(parsed_endpoint, path, **query_args) -> str:
query = urllib.parse.urlencode(query_args)
return parsed_endpoint._replace(path=path, query=query).geturl()

Expand Down
8 changes: 5 additions & 3 deletions renku/core/dataset/providers/dataverse.py
Expand Up @@ -546,9 +546,11 @@ def _post(self, url, json=None, data=None, files=None):

@staticmethod
def _check_response(response):
if response.status_code not in [200, 201, 202]:
if response.status_code == 401:
raise errors.AuthenticationError("Access unauthorized - update access token.")
from renku.core.util import requests

try:
requests.check_response(response=response)
except errors.RequestError:
json_res = response.json()
raise errors.ExportError(
"HTTP {} - Cannot export dataset: {}".format(
Expand Down
11 changes: 7 additions & 4 deletions renku/core/dataset/providers/olos.py
Expand Up @@ -231,7 +231,8 @@ def upload_file(self, full_path, path_in_dataset):

return response

def _make_url(self, server_url, api_path, **query_params):
@staticmethod
def _make_url(server_url, api_path, **query_params):
"""Create URL for creating a dataset."""
url_parts = urlparse.urlparse(server_url)

Expand Down Expand Up @@ -259,15 +260,17 @@ def _post(self, url, json=None, data=None, files=None):

@staticmethod
def _check_response(response):
from renku.core.util import requests

if len(response.history) > 0:
raise errors.ExportError(
f"Couldn't execute request to {response.request.url}, got redirected to {response.url}."
"Maybe you mixed up http and https in the server url?"
)

if response.status_code not in [200, 201, 202]:
if response.status_code == 401:
raise errors.AuthenticationError("Access unauthorized - update access token.")
try:
requests.check_response(response=response)
except errors.RequestError:
json_res = response.json()
raise errors.ExportError(
"HTTP {} - Cannot export dataset: {}".format(
Expand Down
29 changes: 25 additions & 4 deletions renku/core/dataset/providers/zenodo.py
Expand Up @@ -28,12 +28,14 @@
import attr
from tqdm import tqdm

from renku.core import errors
from renku.core.dataset.providers.api import ExporterApi, ProviderApi, ProviderRecordSerializerApi
from renku.core.util.file_size import bytes_to_unit

if TYPE_CHECKING:
from renku.core.dataset.providers.models import ProviderDataset


ZENODO_BASE_URL = "https://zenodo.org"
ZENODO_SANDBOX_URL = "https://sandbox.zenodo.org/"

Expand Down Expand Up @@ -358,7 +360,7 @@ def new_deposition(self):
response = requests.post(
url=self.new_deposit_url, params=self.exporter.default_params, json={}, headers=self.exporter.HEADERS
)
requests.check_response(response)
self._check_response(response)

return response

Expand All @@ -371,7 +373,7 @@ def upload_file(self, filepath, path_in_repo):
response = requests.post(
url=self.upload_file_url, params=self.exporter.default_params, data=request_payload, files=file
)
requests.check_response(response)
self._check_response(response)

return response

Expand Down Expand Up @@ -402,7 +404,7 @@ def attach_metadata(self, dataset, tag):
data=json.dumps(request_payload),
headers=self.exporter.HEADERS,
)
requests.check_response(response)
self._check_response(response)

return response

Expand All @@ -411,7 +413,7 @@ def publish_deposition(self, secret):
from renku.core.util import requests

response = requests.post(url=self.publish_url, params=self.exporter.default_params)
requests.check_response(response)
self._check_response(response)

return response

Expand All @@ -420,6 +422,25 @@ def __attrs_post_init__(self):
response = self.new_deposition()
self.id = response.json()["id"]

@staticmethod
def _check_response(response):
from renku.core.util import requests

try:
requests.check_response(response=response)
except errors.RequestError:
if response.status_code == 400:
err_response = response.json()
messages = [
'"{0}" failed with "{1}"'.format(err["field"], err["message"]) for err in err_response["errors"]
]

raise errors.ExportError(
"\n" + "\n".join(messages) + "\nSee `renku dataset edit -h` for details on how to edit" " metadata"
)
else:
raise errors.ExportError(response.content)


@attr.s
class ZenodoExporter(ExporterApi):
Expand Down
23 changes: 8 additions & 15 deletions renku/core/util/requests.py
Expand Up @@ -107,21 +107,14 @@ def get_redirect_url(url) -> str:

def check_response(response):
"""Check for expected response status code."""
if response.status_code not in [200, 201, 202]:
if response.status_code == 401:
raise errors.AuthenticationError("Access unauthorized - update access token.")

if response.status_code == 400:
err_response = response.json()
messages = [
'"{0}" failed with "{1}"'.format(err["field"], err["message"]) for err in err_response["errors"]
]

raise errors.ExportError(
"\n" + "\n".join(messages) + "\nSee `renku dataset edit -h` for details on how to edit" " metadata"
)

raise errors.ExportError(response.content)
if response.status_code in [200, 201, 202]:
return
elif response.status_code == 401:
raise errors.AuthenticationError("Access unauthorized - update access token.")
else:
content = response.content.decode("utf-8") if response.content else ""
message = f"Request failed with code {response.status_code}: {content}"
raise errors.RequestError(message)


def download_file(base_directory: Union[Path, str], url: str, filename, extract, chunk_size=16384):
Expand Down
51 changes: 39 additions & 12 deletions tests/cli/fixtures/cli_gateway.py
Expand Up @@ -16,46 +16,73 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Renku CLI fixtures for Gateway."""

import json
import urllib

import pytest
import responses
from _pytest.monkeypatch import MonkeyPatch

ENDPOINT = "renku.deployment.ch"
ACCESS_TOKEN = "jwt-token"
USER_CODE = "valid_user_code"
DEVICE_CODE = "valid-device-code"


@pytest.fixture(scope="module")
def mock_login():
"""Monkey patch webbrowser module for renku login."""
"""Monkey patch webbrowser package and keycloak endpoints for renku login."""
import webbrowser

with MonkeyPatch().context() as monkey_patch:
monkey_patch.setattr(webbrowser, "open_new_tab", lambda _: None)
monkey_patch.setattr(webbrowser, "open_new_tab", lambda _: True)

with responses.RequestsMock(assert_all_requests_are_fired=False) as requests_mock:

def callback(token):
def func(request):
if request.params.get("server_nonce") == USER_CODE:
def device_callback(request):
data = dict(urllib.parse.parse_qsl(request.body))
if data.get("client_id") != "renku-cli":
return 400, {"Content-Type": "application/json"}, json.dumps({"error": "invalid_client"})

data = {
"verification_uri": f"https://{ENDPOINT}/auth/realms/Renku/device",
"user_code": "ABC-DEF",
"interval": 0,
"device_code": DEVICE_CODE,
}
return 200, {"Content-Type": "application/json"}, json.dumps(data)

def create_token_callback(token):
def token_callback(request):
data = dict(urllib.parse.parse_qsl(request.body))
if (
data.get("device_code") == DEVICE_CODE
and data.get("client_id") == "renku-cli"
and data.get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code"
):
return 200, {"Content-Type": "application/json"}, json.dumps({"access_token": token})

return 404, {"Content-Type": "application/json"}, ""
return 400, {"Content-Type": "application/json"}, ""

return func
return token_callback

requests_mock.add_passthru("https://pypi.org/")

class RequestMockWrapper:
@staticmethod
def add_endpoint_token(endpoint, token):
"""Add a mocked endpoint and its access token."""
def add_device_auth(endpoint, token):
"""Add a mocked endpoint."""
requests_mock.add_callback(
responses.POST,
f"https://{endpoint}/auth/realms/Renku/protocol/openid-connect/auth/device",
callback=device_callback,
)
requests_mock.add_callback(
responses.GET, f"https://{endpoint}/api/auth/cli-token", callback=callback(token)
responses.POST,
f"https://{endpoint}/auth/realms/Renku/protocol/openid-connect/token",
callback=create_token_callback(token),
)

RequestMockWrapper.add_endpoint_token(ENDPOINT, ACCESS_TOKEN)
RequestMockWrapper.add_device_auth(ENDPOINT, ACCESS_TOKEN)

yield RequestMockWrapper

0 comments on commit eae254e

Please sign in to comment.