From 1d9f56c69a6c12fa5bcbfe7e021f4ae4d0b3c406 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 29 Aug 2019 12:20:37 -0400 Subject: [PATCH 1/4] Add Cloud auth flows to the Client --- src/prefect/agent/agent.py | 2 +- .../agent/kubernetes/resource_manager.py | 2 +- src/prefect/cli/auth.py | 6 +- src/prefect/client/client.py | 296 ++++++++-- tests/agent/test_agent.py | 2 +- tests/agent/test_k8s_agent.py | 2 +- tests/agent/test_local_agent.py | 4 +- tests/agent/test_nomad_agent.py | 2 +- tests/cli/test_auth.py | 130 ++--- tests/cli/test_create.py | 36 +- tests/client/test_client.py | 339 ++---------- tests/client/test_client_auth.py | 514 ++++++++++++++++++ tests/client/test_secrets.py | 7 +- tests/conftest.py | 37 +- 14 files changed, 908 insertions(+), 471 deletions(-) create mode 100644 tests/client/test_client_auth.py diff --git a/src/prefect/agent/agent.py b/src/prefect/agent/agent.py index e94b3a6ea53a..534b75651c56 100644 --- a/src/prefect/agent/agent.py +++ b/src/prefect/agent/agent.py @@ -36,7 +36,7 @@ class Agent: """ def __init__(self) -> None: - self.client = Client(token=config.cloud.agent.get("auth_token")) + self.client = Client(api_token=config.cloud.agent.get("auth_token")) logger = logging.getLogger("agent") logger.setLevel(logging.DEBUG) diff --git a/src/prefect/agent/kubernetes/resource_manager.py b/src/prefect/agent/kubernetes/resource_manager.py index b667436549c7..4d54601ddf7e 100644 --- a/src/prefect/agent/kubernetes/resource_manager.py +++ b/src/prefect/agent/kubernetes/resource_manager.py @@ -25,7 +25,7 @@ def __init__(self) -> None: self.loop_interval = prefect_config.cloud.agent.resource_manager.get( "loop_interval" ) - self.client = Client(token=prefect_config.cloud.agent.get("auth_token")) + self.client = Client(api_token=prefect_config.cloud.agent.get("auth_token")) self.namespace = os.getenv("NAMESPACE", "default") logger = logging.getLogger("resource-manager") diff --git a/src/prefect/cli/auth.py b/src/prefect/cli/auth.py index f89174ab0981..8a0d0de576d8 100644 --- a/src/prefect/cli/auth.py +++ b/src/prefect/cli/auth.py @@ -44,8 +44,7 @@ def login(token): abort=True, ) - client = Client() - client.login(api_token=token) + client = Client(api_token=token) # Verify login obtained a valid api token try: @@ -59,4 +58,7 @@ def login(token): click.secho("Error attempting to communicate with Prefect Cloud", fg="red") return + # save token + client.save_api_token() + click.secho("Login successful", fg="green") diff --git a/src/prefect/client/client.py b/src/prefect/client/client.py index 63379f7a02e4..3628c78053a2 100644 --- a/src/prefect/client/client.py +++ b/src/prefect/client/client.py @@ -1,17 +1,19 @@ -import base64 import datetime import json -import logging import os +import uuid +from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union import pendulum import requests +import toml from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry +from slugify import slugify import prefect -from prefect.utilities.exceptions import AuthorizationError, ClientError +from prefect.utilities.exceptions import ClientError, AuthorizationError from prefect.utilities.graphql import ( EnumValue, GraphQLResult, @@ -62,36 +64,36 @@ class Client: token will only be present in the current context. Args: - - graphql_server (str, optional): the URL to send all GraphQL requests + - api_server (str, optional): the URL to send all GraphQL requests to; if not provided, will be pulled from `cloud.graphql` config var - - token (str, optional): a Prefect Cloud auth token for communication; if not - provided, will be pulled from `cloud.auth_token` config var + - api_token (str, optional): a Prefect Cloud API token, taken from + `config.cloud.auth_token` if not provided. If this token is USER-scoped, it may + be used to log in to any tenant that the user is a member of. In that case, + ephemeral JWTs will be loaded as necessary. Otherwise, the API token itself + will be used as authorization. """ - def __init__(self, graphql_server: str = None, token: str = None): + def __init__(self, api_server: str = None, api_token: str = None): + self._access_token = None + self._refresh_token = None + self._access_token_expires_at = pendulum.now() + self._active_tenant_id = None - if not graphql_server: - graphql_server = prefect.config.cloud.get("graphql") - self.graphql_server = graphql_server + # store api server + self.api_server = api_server or prefect.config.cloud.get("graphql") - token = token or prefect.config.cloud.get("auth_token", None) + # store api token + self._api_token = api_token or prefect.config.cloud.get("auth_token", None) - self.token_is_local = False - if token is None: - if os.path.exists(self.local_token_path): - with open(self.local_token_path, "r") as f: - token = f.read() or None - self.token_is_local = True - - self.token = token + # if no api token was passed, attempt to load state from local storage + if not self._api_token: + settings = self._load_local_settings() + self._api_token = settings.get("api_token") - @property - def local_token_path(self) -> str: - """ - Returns the local token path corresponding to the provided graphql_server - """ - graphql_server = (self.graphql_server or "").replace("/", "_") - return os.path.expanduser("~/.prefect/tokens/{}".format(graphql_server)) + if self._api_token: + self._active_tenant_id = settings.get("active_tenant_id") + if self._active_tenant_id: + self.login_to_tenant(tenant_id=self._active_tenant_id) # ------------------------------------------------------------------------- # Utilities @@ -102,6 +104,7 @@ def get( server: str = None, headers: dict = None, params: Dict[str, JSONLike] = None, + token: str = None, ) -> dict: """ Convenience function for calling the Prefect API with token auth and GET request @@ -110,15 +113,21 @@ def get( - path (str): the path of the API url. For example, to GET http://prefect-server/v1/auth/login, path would be 'auth/login'. - server (str, optional): the server to send the GET request to; - defaults to `self.graphql_server` + defaults to `self.api_server` - headers (dict, optional): Headers to pass with the request - params (dict): GET parameters + - token (str): an auth token. If not supplied, the `client.access_token` is used. Returns: - dict: Dictionary representation of the request made """ response = self._request( - method="GET", path=path, params=params, server=server, headers=headers + method="GET", + path=path, + params=params, + server=server, + headers=headers, + token=token, ) if response.text: return response.json() @@ -131,6 +140,7 @@ def post( server: str = None, headers: dict = None, params: Dict[str, JSONLike] = None, + token: str = None, ) -> dict: """ Convenience function for calling the Prefect API with token auth and POST request @@ -139,15 +149,21 @@ def post( - path (str): the path of the API url. For example, to POST http://prefect-server/v1/auth/login, path would be 'auth/login'. - server (str, optional): the server to send the POST request to; - defaults to `self.graphql_server` + defaults to `self.api_server` - headers(dict): headers to pass with the request - params (dict): POST parameters + - token (str): an auth token. If not supplied, the `client.access_token` is used. Returns: - dict: Dictionary representation of the request made """ response = self._request( - method="POST", path=path, params=params, server=server, headers=headers + method="POST", + path=path, + params=params, + server=server, + headers=headers, + token=token, ) if response.text: return response.json() @@ -160,6 +176,7 @@ def graphql( raise_on_error: bool = True, headers: Dict[str, str] = None, variables: Dict[str, JSONLike] = None, + token: str = None, ) -> GraphQLResult: """ Convenience function for running queries against the Prefect GraphQL API @@ -173,6 +190,7 @@ def graphql( request - variables (dict): Variables to be filled into a query with the key being equivalent to the variables that are accepted by the query + - token (str): an auth token. If not supplied, the `client.access_token` is used. Returns: - dict: Data returned from the GraphQL query @@ -182,12 +200,15 @@ def graphql( """ result = self.post( path="", - server=self.graphql_server, + server=self.api_server, headers=headers, params=dict(query=parse_graphql(query), variables=json.dumps(variables)), + token=token, ) if raise_on_error and "errors" in result: + if "Malformed Authorization header" in str(result["errors"]): + raise AuthorizationError(result["errors"]) raise ClientError(result["errors"]) else: return as_nested_dict(result, GraphQLResult) # type: ignore @@ -199,6 +220,7 @@ def _request( params: Dict[str, JSONLike] = None, server: str = None, headers: dict = None, + token: str = None, ) -> "requests.models.Response": """ Runs any specified request (GET, POST, DELETE) against the server @@ -210,6 +232,7 @@ def _request( - server (str, optional): The server to make requests against, base API server is used if not specified - headers (dict, optional): Headers to pass with the request + - token (str): an auth token. If not supplied, the `client.access_token` is used. Returns: - requests.models.Response: The response returned from the request @@ -220,18 +243,20 @@ def _request( - requests.HTTPError: if a status code is returned that is not `200` or `401` """ if server is None: - server = self.graphql_server + server = self.api_server assert isinstance(server, str) # mypy assert - if self.token is None: - raise AuthorizationError("No token found; call Client.login() to set one.") + if token is None: + token = self.get_auth_token() url = os.path.join(server, path.lstrip("/")).rstrip("/") params = params or {} headers = headers or {} - headers.update({"Authorization": "Bearer {}".format(self.token)}) + if token: + headers["Authorization"] = "Bearer {}".format(token) + session = requests.Session() retries = Retry( total=6, @@ -258,33 +283,196 @@ def _request( # Auth # ------------------------------------------------------------------------- - def login(self, api_token: str) -> None: + def _local_settings_path(self) -> Path: + """ + Returns the local settings directory corresponding to the current API servers + """ + path = "{home}/client/{server}".format( + home=prefect.config.home_dir, + server=slugify(self.api_server, regex_pattern=r"[^-\.a-z0-9]+"), + ) + return Path(os.path.expanduser(path)) / "settings.toml" + + def _save_local_settings(self, settings: dict) -> None: + """ + Writes settings to local storage + """ + self._local_settings_path().parent.mkdir(exist_ok=True, parents=True) + with self._local_settings_path().open("w+") as f: + toml.dump(settings, f) + + def _load_local_settings(self) -> dict: + """ + Loads settings from local storage + """ + if self._local_settings_path().exists(): + with self._local_settings_path().open("r") as f: + return toml.load(f) # type: ignore + return {} + + def save_api_token(self) -> None: + """ + Saves the API token in local storage. + """ + settings = self._load_local_settings() + settings["api_token"] = self._api_token + self._save_local_settings(settings) + + def get_auth_token(self) -> str: + """ + Returns an auth token: + - if no explicit access token is stored, returns the api token + - if there is an access token: + - if there's a refresh token and the access token expires in the next 30 seconds, + then we refresh the access token and store the result + - return the access token + + Returns: + - str: the access token + """ + if not self._access_token: + return self._api_token + + expiration = self._access_token_expires_at or pendulum.now() + if self._refresh_token and pendulum.now().add(seconds=30) > expiration: + self._refresh_access_token() + + return self._access_token + + def get_available_tenants(self) -> List[Dict]: + """ + Returns a list of available tenants. + + NOTE: this should only be called by users who have provided a USER-scoped API token. + + Returns: + - List[Dict]: a list of dictionaries containing the id, slug, and name of + available tenants + """ + result = self.graphql( + {"query": {"tenant(order_by: {slug: asc})": {"id", "slug", "name"}}}, + # use the API token to see all available tenants + token=self._api_token, + ) # type: ignore + return result.data.tenant # type: ignore + + def login_to_tenant(self, tenant_slug: str = None, tenant_id: str = None) -> bool: """ - Logs in to Prefect Cloud with an API token. The token is written to local storage - so it persists across Prefect sessions. + Log in to a specific tenant + + NOTE: this should only be called by users who have provided a USER-scoped API token. Args: - - api_token (str): a Prefect Cloud API token + - tenant_slug (str): the tenant's slug + - tenant_id (str): the tenant's id + + Returns: + - bool: True if the login was successful Raises: - - AuthorizationError if unable to login to the server (request does not return `200`) - """ - if not os.path.exists(os.path.dirname(self.local_token_path)): - os.makedirs(os.path.dirname(self.local_token_path)) - with open(self.local_token_path, "w+") as f: - f.write(api_token) - self.token = api_token - self.token_is_local = True + - ValueError: if at least one of `tenant_slug` or `tenant_id` isn't provided + - ValueError: if the `tenant_id` is not a valid UUID + - ValueError: if no matching tenants are found - def logout(self) -> None: """ - Deletes the token from this client, and removes it from local storage. + + if tenant_slug is None and tenant_id is None: + raise ValueError( + "At least one of `tenant_slug` or `tenant_id` must be provided." + ) + elif tenant_id: + try: + uuid.UUID(tenant_id) + except ValueError: + raise ValueError("The `tenant_id` must be a valid UUID.") + + tenant = self.graphql( + { + "query($slug: String, $id: uuid)": { + "tenant(where: {slug: { _eq: $slug }, id: { _eq: $id } })": {"id"} + } + }, + variables=dict(slug=tenant_slug, id=tenant_id), + # use the API token to query the tenant + token=self._api_token, + ) # type: ignore + if not tenant.data.tenant: # type: ignore + raise ValueError("No matching tenants found.") + + tenant_id = tenant.data.tenant[0].id # type: ignore + + payload = self.graphql( + { + "mutation($input: switchTenantInput!)": { + "switchTenant(input: $input)": { + "accessToken", + "expiresIn", + "refreshToken", + } + } + }, + variables=dict(input=dict(tenantId=tenant_id)), + # Use the API token to switch tenants + token=self._api_token, + ) # type: ignore + self._access_token = payload.data.switchTenant.accessToken # type: ignore + self._access_token_expires_at = pendulum.now().add( + seconds=payload.data.switchTenant.expiresIn # type: ignore + ) + self._refresh_token = payload.data.switchTenant.refreshToken # type: ignore + self._active_tenant_id = tenant_id + + # save the tenant setting + settings = self._load_local_settings() + settings["active_tenant_id"] = self._active_tenant_id + self._save_local_settings(settings) + + return True + + def logout_from_tenant(self) -> None: + self._access_token = None + self._refresh_token = None + self._active_tenant_id = None + + # remove the tenant setting + settings = self._load_local_settings() + settings["active_tenant_id"] = None + self._save_local_settings(settings) + + def _refresh_access_token(self) -> bool: """ - self.token = None - if self.token_is_local: - if os.path.exists(self.local_token_path): - os.remove(self.local_token_path) - self.token_is_local = False + Refresh the client's JWT access token. + + NOTE: this should only be called by users who have provided a USER-scoped API token. + + Returns: + - bool: True if the refresh succeeds + """ + payload = self.graphql( + { + "mutation($input: refreshTokenInput!)": { + "refreshToken(input: $input)": { + "accessToken", + "expiresIn", + "refreshToken", + } + } + }, + variables=dict(input=dict(accessToken=self._access_token)), + # pass the refresh token as the auth header + token=self._refresh_token, + ) # type: ignore + self._access_token = payload.data.refreshToken.accessToken # type: ignore + self._access_token_expires_at = pendulum.now().add( + seconds=payload.data.refreshToken.expiresIn # type: ignore + ) + self._refresh_token = payload.data.refreshToken.refreshToken # type: ignore + + return True + + # ------------------------------------------------------------------------- + # Actions + # ------------------------------------------------------------------------- def deploy( self, diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index c24e6a078250..55c6a43424b3 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -17,7 +17,7 @@ def test_agent_init(): def test_agent_config_options(): with set_temporary_config({"cloud.agent.auth_token": "TEST_TOKEN"}): agent = Agent() - assert agent.client.token == "TEST_TOKEN" + assert agent.client.get_auth_token() == "TEST_TOKEN" assert agent.logger diff --git a/tests/agent/test_k8s_agent.py b/tests/agent/test_k8s_agent.py index 863d44501c71..209ad4022858 100644 --- a/tests/agent/test_k8s_agent.py +++ b/tests/agent/test_k8s_agent.py @@ -28,7 +28,7 @@ def test_k8s_agent_config_options(monkeypatch): with set_temporary_config({"cloud.agent.auth_token": "TEST_TOKEN"}): agent = KubernetesAgent() assert agent - assert agent.client.token == "TEST_TOKEN" + assert agent.client.get_auth_token() == "TEST_TOKEN" assert agent.logger assert agent.batch_client diff --git a/tests/agent/test_local_agent.py b/tests/agent/test_local_agent.py index fb575b166042..e2aa1430d424 100644 --- a/tests/agent/test_local_agent.py +++ b/tests/agent/test_local_agent.py @@ -23,7 +23,7 @@ def test_local_agent_config_options(monkeypatch): with set_temporary_config({"cloud.agent.auth_token": "TEST_TOKEN"}): agent = LocalAgent() - assert agent.client.token == "TEST_TOKEN" + assert agent.client.get_auth_token() == "TEST_TOKEN" assert agent.logger assert not agent.no_pull assert api.call_args[1]["base_url"] == "unix://var/run/docker.sock" @@ -35,7 +35,7 @@ def test_local_agent_config_options_populated(monkeypatch): with set_temporary_config({"cloud.agent.auth_token": "TEST_TOKEN"}): agent = LocalAgent(base_url="url", no_pull=True) - assert agent.client.token == "TEST_TOKEN" + assert agent.client.get_auth_token() == "TEST_TOKEN" assert agent.logger assert agent.no_pull assert api.call_args[1]["base_url"] == "url" diff --git a/tests/agent/test_nomad_agent.py b/tests/agent/test_nomad_agent.py index 1d4ed797a0dc..35fa4b22eb3a 100644 --- a/tests/agent/test_nomad_agent.py +++ b/tests/agent/test_nomad_agent.py @@ -17,7 +17,7 @@ def test_nomad_agent_config_options(): with set_temporary_config({"cloud.agent.auth_token": "TEST_TOKEN"}): agent = NomadAgent() assert agent - assert agent.client.token == "TEST_TOKEN" + assert agent.client.get_auth_token() == "TEST_TOKEN" assert agent.logger diff --git a/tests/cli/test_auth.py b/tests/cli/test_auth.py index c2e527758491..25bbf8a746de 100644 --- a/tests/cli/test_auth.py +++ b/tests/cli/test_auth.py @@ -25,88 +25,48 @@ def test_auth_help(): assert "Handle Prefect Cloud authorization." in result.output -def test_auth_login(monkeypatch): - - with tempfile.NamedTemporaryFile() as f: - monkeypatch.setattr("prefect.client.Client.local_token_path", f.name) - post = MagicMock( - return_value=MagicMock( - json=MagicMock(return_value=dict(data=dict(tenant="id"))) - ) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) - - with set_temporary_config( - {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} - ): - runner = CliRunner() - result = runner.invoke(auth, ["login", "--token", "test"]) - assert result.exit_code == 0 - assert "Login successful" in result.output - - -def test_auth_login_client_error(monkeypatch): - - with tempfile.NamedTemporaryFile() as f: - monkeypatch.setattr("prefect.client.Client.local_token_path", f.name) - post = MagicMock( - return_value=MagicMock( - json=MagicMock(return_value=dict(errors=dict(error="bad"))) - ) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) - - with set_temporary_config( - {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} - ): - runner = CliRunner() - result = runner.invoke(auth, ["login", "--token", "test"]) - assert result.exit_code == 0 - assert "Error attempting to communicate with Prefect Cloud" in result.output - - -def test_auth_login_confirm(monkeypatch): - - with tempfile.NamedTemporaryFile() as f: - monkeypatch.setattr("prefect.client.Client.local_token_path", f.name) - post = MagicMock( - return_value=MagicMock( - json=MagicMock(return_value=dict(data=dict(hello="hi"))) - ) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) - - with set_temporary_config( - {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} - ): - runner = CliRunner() - result = runner.invoke(auth, ["login", "--token", "test"], input="Y") - assert result.exit_code == 0 - assert "Login successful" in result.output - - -def test_auth_login_not_confirm(monkeypatch): - - with tempfile.NamedTemporaryFile() as f: - monkeypatch.setattr("prefect.client.Client.local_token_path", f.name) - post = MagicMock( - return_value=MagicMock( - json=MagicMock(return_value=dict(data=dict(hello="hi"))) - ) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) - - with set_temporary_config( - {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} - ): - runner = CliRunner() - result = runner.invoke(auth, ["login", "--token", "test"], input="N") - assert result.exit_code == 1 +def test_auth_login(patch_post): + patch_post(dict(data=dict(tenant="id"))) + + with set_temporary_config( + {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} + ): + runner = CliRunner() + result = runner.invoke(auth, ["login", "--token", "test"]) + assert result.exit_code == 0 + assert "Login successful" in result.output + + +def test_auth_login_client_error(patch_post): + patch_post(dict(errors=dict(error="bad"))) + + with set_temporary_config( + {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} + ): + runner = CliRunner() + result = runner.invoke(auth, ["login", "--token", "test"]) + assert result.exit_code == 0 + assert "Error attempting to communicate with Prefect Cloud" in result.output + + +def test_auth_login_confirm(patch_post): + patch_post(dict(data=dict(hello="hi"))) + + with set_temporary_config( + {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} + ): + runner = CliRunner() + result = runner.invoke(auth, ["login", "--token", "test"], input="Y") + assert result.exit_code == 0 + assert "Login successful" in result.output + + +def test_auth_login_not_confirm(patch_post): + patch_post(dict(data=dict(hello="hi"))) + + with set_temporary_config( + {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} + ): + runner = CliRunner() + result = runner.invoke(auth, ["login", "--token", "test"], input="N") + assert result.exit_code == 1 diff --git a/tests/cli/test_create.py b/tests/cli/test_create.py index f26c75803a06..61c6c7978671 100644 --- a/tests/cli/test_create.py +++ b/tests/cli/test_create.py @@ -26,16 +26,8 @@ def test_create_help(): ) -def test_create_project(monkeypatch): - - post = MagicMock( - return_value=MagicMock( - json=MagicMock(return_value=dict(data=dict(createProject=dict(id="id")))) - ) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) +def test_create_project(patch_post): + patch_post(dict(data=dict(createProject=dict(id="id")))) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} @@ -46,16 +38,8 @@ def test_create_project(monkeypatch): assert "test created" in result.output -def test_create_project_error(monkeypatch): - - post = MagicMock( - return_value=MagicMock( - json=MagicMock(return_value=dict(errors=dict(error="bad"))) - ) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) +def test_create_project_error(patch_post): + patch_post(dict(errors=dict(error="bad"))) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} @@ -66,16 +50,8 @@ def test_create_project_error(monkeypatch): assert "Error creating project" in result.output -def test_create_project_description(monkeypatch): - - post = MagicMock( - return_value=MagicMock( - json=MagicMock(return_value=dict(data=dict(createProject=dict(id="id")))) - ) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) +def test_create_project_description(patch_post): + patch_post(dict(data=dict(createProject=dict(id="id")))) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} diff --git a/tests/client/test_client.py b/tests/client/test_client.py index eb9b7b845a4c..54f448cdd37f 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -17,176 +17,10 @@ from prefect.utilities.exceptions import AuthorizationError, ClientError from prefect.utilities.graphql import GraphQLResult, decompress -################################# -##### Client Tests -################################# +def test_client_posts_to_api_server(patch_post): + post = patch_post(dict(success=True)) -def test_client_initializes_from_config(): - with set_temporary_config( - {"cloud.graphql": "graphql_server", "cloud.auth_token": "token"} - ): - client = Client() - assert client.graphql_server == "graphql_server" - assert client.token == "token" - - -def test_client_initializes_and_prioritizes_kwargs(): - with set_temporary_config( - {"cloud.graphql": "graphql_server", "cloud.auth_token": "token"} - ): - client = Client(graphql_server="my-graphql") - assert client.graphql_server == "my-graphql" - assert client.token == "token" - - -def test_client_token_path_depends_on_graphql_server(): - assert Client(graphql_server="a").local_token_path == os.path.expanduser( - "~/.prefect/tokens/a" - ) - - assert Client(graphql_server="b").local_token_path == os.path.expanduser( - "~/.prefect/tokens/b" - ) - - -def test_client_token_initializes_from_file(monkeypatch): - - with tempfile.NamedTemporaryFile() as f: - f.write(b"TOKEN") - f.seek(0) - monkeypatch.setattr("prefect.client.Client.local_token_path", f.name) - - with set_temporary_config({"cloud.auth_token": None}): - client = Client() - assert client.token == "TOKEN" - - -def test_client_token_priotizes_config_over_file(monkeypatch): - with tempfile.NamedTemporaryFile() as f: - f.write(b"TOKEN") - f.seek(0) - monkeypatch.setattr("prefect.client.Client.local_token_path", f.name) - - with set_temporary_config({"cloud.auth_token": "CONFIG-TOKEN"}): - client = Client() - assert client.token == "CONFIG-TOKEN" - - -def test_login_writes_token(monkeypatch): - with tempfile.NamedTemporaryFile() as f: - monkeypatch.setattr("prefect.client.Client.local_token_path", f.name) - - client = Client() - - client.login(api_token="a") - assert f.read() == b"a" - - f.seek(0) - - client.login(api_token="b") - assert f.read() == b"b" - - -def test_login_creates_directories(monkeypatch): - with tempfile.TemporaryDirectory() as tmp: - - f_path = os.path.join(tmp, "a", "b", "c") - - monkeypatch.setattr("prefect.client.Client.local_token_path", f_path) - - client = Client() - - client.login(api_token="a") - - with open(f_path) as f: - assert f.read() == "a" - - -def test_logout_removes_token(monkeypatch): - with tempfile.NamedTemporaryFile(delete=False) as f: - monkeypatch.setattr("prefect.client.Client.local_token_path", f.name) - - client = Client() - - client.login(api_token="a") - assert f.read() == b"a" - - client.logout() - assert not os.path.exists(f.name) - - -def test_client_posts_raises_with_no_token(monkeypatch): - post = MagicMock() - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) - with set_temporary_config( - {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": None} - ): - client = Client() - with pytest.raises(AuthorizationError, match="Client.login"): - result = client.post("/foo/bar") - - -def test_headers_are_passed_to_get(monkeypatch): - get = MagicMock() - session = MagicMock() - session.return_value.get = get - monkeypatch.setattr("requests.Session", session) - with set_temporary_config( - {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} - ): - client = Client() - client.get("/foo/bar", headers={"x": "y", "Authorization": "z"}) - assert get.called - assert get.call_args[1]["headers"] == { - "x": "y", - "Authorization": "Bearer secret_token", - } - - -def test_headers_are_passed_to_post(monkeypatch): - post = MagicMock() - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) - with set_temporary_config( - {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} - ): - client = Client() - client.post("/foo/bar", headers={"x": "y", "Authorization": "z"}) - assert post.called - assert post.call_args[1]["headers"] == { - "x": "y", - "Authorization": "Bearer secret_token", - } - - -def test_headers_are_passed_to_graphql(monkeypatch): - post = MagicMock() - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) - with set_temporary_config( - {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} - ): - client = Client() - client.graphql("query {}", headers={"x": "y", "Authorization": "z"}) - assert post.called - assert post.call_args[1]["headers"] == { - "x": "y", - "Authorization": "Bearer secret_token", - } - - -def test_client_posts_to_graphql_server(monkeypatch): - post = MagicMock( - return_value=MagicMock(json=MagicMock(return_value=dict(success=True))) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -197,15 +31,9 @@ def test_client_posts_to_graphql_server(monkeypatch): assert post.call_args[0][0] == "http://my-cloud.foo/foo/bar" -def test_client_posts_graphql_to_graphql_server(monkeypatch): - post = MagicMock( - return_value=MagicMock( - json=MagicMock(return_value=dict(data=dict(success=True))) - ) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) +def test_client_posts_graphql_to_api_server(patch_post): + post = patch_post(dict(data=dict(success=True))) + with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -217,15 +45,9 @@ def test_client_posts_graphql_to_graphql_server(monkeypatch): ## test actual mutation and query handling -def test_graphql_errors_get_raised(monkeypatch): - post = MagicMock( - return_value=MagicMock( - json=MagicMock(return_value=dict(data="42", errors="GraphQL issue!")) - ) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) +def test_graphql_errors_get_raised(patch_post): + patch_post(dict(data="42", errors="GraphQL issue!")) + with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -235,7 +57,7 @@ def test_graphql_errors_get_raised(monkeypatch): @pytest.mark.parametrize("compressed", [True, False]) -def test_client_deploy(monkeypatch, compressed): +def test_client_deploy(patch_post, compressed): if compressed: response = { "data": { @@ -247,10 +69,8 @@ def test_client_deploy(monkeypatch, compressed): response = { "data": {"project": [{"id": "proj-id"}], "createFlow": {"id": "long-id"}} } - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) + patch_post(response) + with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -263,7 +83,7 @@ def test_client_deploy(monkeypatch, compressed): @pytest.mark.parametrize("compressed", [True, False]) -def test_client_deploy_builds_flow(monkeypatch, compressed): +def test_client_deploy_builds_flow(patch_post, compressed): if compressed: response = { "data": { @@ -275,10 +95,8 @@ def test_client_deploy_builds_flow(monkeypatch, compressed): response = { "data": {"project": [{"id": "proj-id"}], "createFlow": {"id": "long-id"}} } - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) + post = patch_post(response) + with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -303,7 +121,7 @@ def test_client_deploy_builds_flow(monkeypatch, compressed): @pytest.mark.parametrize("compressed", [True, False]) -def test_client_deploy_optionally_avoids_building_flow(monkeypatch, compressed): +def test_client_deploy_optionally_avoids_building_flow(patch_post, compressed): if compressed: response = { "data": { @@ -315,10 +133,8 @@ def test_client_deploy_optionally_avoids_building_flow(monkeypatch, compressed): response = { "data": {"project": [{"id": "proj-id"}], "createFlow": {"id": "long-id"}} } - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) + post = patch_post(response) + with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -342,12 +158,9 @@ def test_client_deploy_optionally_avoids_building_flow(monkeypatch, compressed): assert serialized_flow["storage"] is None -def test_client_deploy_with_bad_proj_name(monkeypatch): - response = {"data": {"project": []}} - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) +def test_client_deploy_with_bad_proj_name(patch_post): + patch_post({"data": {"project": []}}) + with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -359,12 +172,9 @@ def test_client_deploy_with_bad_proj_name(monkeypatch): assert "client.create_project" in str(exc.value) -def test_client_deploy_with_flow_that_cant_be_deserialized(monkeypatch): - response = {"data": {"project": [{"id": "proj-id"}]}} - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) +def test_client_deploy_with_flow_that_cant_be_deserialized(patch_post): + patch_post({"data": {"project": [{"id": "proj-id"}]}}) + with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -385,7 +195,7 @@ def test_client_deploy_with_flow_that_cant_be_deserialized(monkeypatch): client.deploy(flow, project_name="my-default-project", build=False) -def test_get_flow_run_info(monkeypatch): +def test_get_flow_run_info(patch_post): response = { "flow_run_by_pk": { "id": "da344768-5f5d-4eaf-9bca-83815617f713", @@ -424,13 +234,8 @@ def test_get_flow_run_info(monkeypatch): ], } } + post = patch_post(dict(data=response)) - post = MagicMock( - return_value=MagicMock(json=MagicMock(return_value=dict(data=response))) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -448,7 +253,7 @@ def test_get_flow_run_info(monkeypatch): assert result.context is None -def test_get_flow_run_info_with_nontrivial_payloads(monkeypatch): +def test_get_flow_run_info_with_nontrivial_payloads(patch_post): response = { "flow_run_by_pk": { "id": "da344768-5f5d-4eaf-9bca-83815617f713", @@ -487,13 +292,8 @@ def test_get_flow_run_info_with_nontrivial_payloads(monkeypatch): ], } } + post = patch_post(dict(data=response)) - post = MagicMock( - return_value=MagicMock(json=MagicMock(return_value=dict(data=response))) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -515,14 +315,8 @@ def test_get_flow_run_info_with_nontrivial_payloads(monkeypatch): assert result.context["my_val"] == "test" -def test_get_flow_run_info_raises_informative_error(monkeypatch): - response = {"flow_run_by_pk": None} - post = MagicMock( - return_value=MagicMock(json=MagicMock(return_value=dict(data=response))) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) +def test_get_flow_run_info_raises_informative_error(patch_post): + post = patch_post(dict(data={"flow_run_by_pk": None})) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -531,12 +325,10 @@ def test_get_flow_run_info_raises_informative_error(monkeypatch): client.get_flow_run_info(flow_run_id="74-salt") -def test_set_flow_run_state(monkeypatch): +def test_set_flow_run_state(patch_post): response = {"data": {"setFlowRunState": {"id": 1}}} - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) + post = patch_post(response) + with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -547,15 +339,13 @@ def test_set_flow_run_state(monkeypatch): assert result is None -def test_set_flow_run_state_with_error(monkeypatch): +def test_set_flow_run_state_with_error(patch_post): response = { "data": {"setFlowRunState": None}, "errors": [{"message": "something went wrong"}], } - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) + post = patch_post(response) + with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -564,7 +354,7 @@ def test_set_flow_run_state_with_error(monkeypatch): client.set_flow_run_state(flow_run_id="74-salt", version=0, state=Pending()) -def test_get_task_run_info(monkeypatch): +def test_get_task_run_info(patch_post): response = { "getOrCreateTaskRun": { "task_run": { @@ -586,12 +376,7 @@ def test_get_task_run_info(monkeypatch): } } - post = MagicMock( - return_value=MagicMock(json=MagicMock(return_value=dict(data=response))) - ) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) + post = patch_post(dict(data=response)) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -607,15 +392,13 @@ def test_get_task_run_info(monkeypatch): assert result.version == 0 -def test_get_task_run_info_with_error(monkeypatch): +def test_get_task_run_info_with_error(patch_post): response = { "data": {"getOrCreateTaskRun": None}, "errors": [{"message": "something went wrong"}], } - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) + post = patch_post(response) + with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -627,13 +410,10 @@ def test_get_task_run_info_with_error(monkeypatch): ) -def test_set_task_run_state(monkeypatch): +def test_set_task_run_state(patch_post): response = {"data": {"setTaskRunState": None}} - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) + post = patch_post(response) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -645,13 +425,10 @@ def test_set_task_run_state(monkeypatch): assert result is None -def test_set_task_run_state_serializes(monkeypatch): +def test_set_task_run_state_serializes(patch_post): response = {"data": {"setTaskRunState": None}} - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) + post = patch_post(response) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -664,16 +441,13 @@ def test_set_task_run_state_serializes(monkeypatch): ) -def test_set_task_run_state_with_error(monkeypatch): +def test_set_task_run_state_with_error(patch_post): response = { "data": {"setTaskRunState": None}, "errors": [{"message": "something went wrong"}], } - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) + post = patch_post(response) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} ): @@ -683,12 +457,8 @@ def test_set_task_run_state_with_error(monkeypatch): client.set_task_run_state(task_run_id="76-salt", version=0, state=Pending()) -def test_write_log_successfully(monkeypatch): - response = {"data": {"writeRunLog": {"success": True}}} - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) +def test_write_log_successfully(patch_post): + patch_post({"data": {"writeRunLog": {"success": True}}}) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} @@ -698,15 +468,10 @@ def test_write_log_successfully(monkeypatch): assert client.write_run_log(flow_run_id="1") is None -def test_write_log_with_error(monkeypatch): - response = { - "data": {"writeRunLog": None}, - "errors": [{"message": "something went wrong"}], - } - post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) - session = MagicMock() - session.return_value.post = post - monkeypatch.setattr("requests.Session", session) +def test_write_log_with_error(patch_post): + patch_post( + {"data": {"writeRunLog": None}, "errors": [{"message": "something went wrong"}]} + ) with set_temporary_config( {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} diff --git a/tests/client/test_client_auth.py b/tests/client/test_client_auth.py new file mode 100644 index 000000000000..cfb4588eaa1f --- /dev/null +++ b/tests/client/test_client_auth.py @@ -0,0 +1,514 @@ +import uuid +import toml +from pathlib import Path +import tempfile +import datetime +import json +import os +from unittest.mock import MagicMock, mock_open + +import marshmallow +import pendulum +import pytest +import requests + +import prefect +from prefect.client.client import Client, FlowRunInfoResult, TaskRunInfoResult +from prefect.engine.result import NoResult, Result, SafeResult +from prefect.engine.state import Pending +from prefect.utilities.configuration import set_temporary_config +from prefect.utilities.exceptions import AuthorizationError, ClientError +from prefect.utilities.graphql import GraphQLResult, decompress + + +@pytest.fixture +def patch_graphql(monkeypatch): + def patch(response): + post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) + session = MagicMock() + session.return_value.post = post + monkeypatch.setattr("requests.Session", session) + + return patch + + +class TestClientConfig: + def test_client_initializes_from_config(self): + with set_temporary_config( + {"cloud.graphql": "api_server", "cloud.auth_token": "token"} + ): + client = Client() + assert client.api_server == "api_server" + assert client._api_token == "token" + + def test_client_initializes_and_prioritizes_kwargs(self): + with set_temporary_config( + {"cloud.graphql": "api_server", "cloud.auth_token": "token"} + ): + client = Client(api_server="my-graphql") + assert client.api_server == "my-graphql" + assert client._api_token == "token" + + def test_client_settings_path_is_path_object(self): + assert isinstance(Client()._local_settings_path(), Path) + + def test_client_settings_path_depends_on_api_server(self, prefect_home_dir): + path = Client( + api_server="https://a-test-api.prefect.test/subdomain" + )._local_settings_path() + expected = os.path.join( + prefect_home_dir, + "client/https-a-test-api.prefect.test-subdomain/settings.toml", + ) + assert str(path) == expected + + def test_client_settings_path_depends_on_home_dir(self): + with set_temporary_config(dict(home_dir="abc/def")): + path = Client(api_server="xyz")._local_settings_path() + expected = "abc/def/client/xyz/settings.toml" + assert str(path) == os.path.expanduser(expected) + + def test_client_token_initializes_from_file(selfmonkeypatch): + with tempfile.TemporaryDirectory() as tmp: + with set_temporary_config({"home_dir": tmp, "cloud.graphql": "xyz"}): + path = Path(tmp) / "client" / "xyz" / "settings.toml" + path.parent.mkdir(parents=True) + with path.open("w") as f: + toml.dump(dict(api_token="FILE_TOKEN"), f) + + client = Client() + assert client._api_token == "FILE_TOKEN" + + def test_client_token_priotizes_config_over_file(selfmonkeypatch): + with tempfile.TemporaryDirectory() as tmp: + with set_temporary_config( + { + "home_dir": tmp, + "cloud.graphql": "xyz", + "cloud.auth_token": "CONFIG_TOKEN", + } + ): + path = Path(tmp) / "client" / "xyz" / "settings.toml" + path.parent.mkdir(parents=True) + with path.open("w") as f: + toml.dump(dict(api_token="FILE_TOKEN"), f) + + client = Client() + assert client._api_token == "CONFIG_TOKEN" + + def test_client_token_priotizes_arg_over_config(self): + with set_temporary_config({"cloud.auth_token": "CONFIG_TOKEN"}): + client = Client(api_token="ARG_TOKEN") + assert client._api_token == "ARG_TOKEN" + + def test_save_local_settings(self): + with tempfile.TemporaryDirectory() as tmp: + with set_temporary_config({"home_dir": tmp, "cloud.graphql": "xyz"}): + path = Path(tmp) / "client" / "xyz" / "settings.toml" + + client = Client(api_token="a") + client.save_api_token() + with path.open("r") as f: + assert toml.load(f)["api_token"] == "a" + + client = Client(api_token="b") + client.save_api_token() + with path.open("r") as f: + assert toml.load(f)["api_token"] == "b" + + def test_load_local_api_token_is_called_when_the_client_is_initialized_without_token( + self + ): + with tempfile.TemporaryDirectory() as tmp: + with set_temporary_config({"home_dir": tmp}): + client = Client(api_token="a") + client.save_api_token() + + client = Client() + assert client._api_token == "a" + + def test_load_local_api_token_is_called_when_the_client_is_initialized_without_token( + self + ): + with tempfile.TemporaryDirectory() as tmp: + with set_temporary_config({"home_dir": tmp}): + client = Client(api_token="a") + client.save_api_token() + + client = Client(api_token="b") + assert client._api_token == "b" + + assert Client()._api_token == "a" + + +class TestTenantAuth: + def test_login_to_tenant_requires_argument(self): + client = Client() + with pytest.raises(ValueError, match="At least one"): + client.login_to_tenant() + + def test_login_to_tenant_requires_valid_uuid(self): + client = Client() + with pytest.raises(ValueError, match="valid UUID"): + client.login_to_tenant(tenant_id="a") + + def test_login_to_client_sets_access_token(self, patch_post): + tenant_id = str(uuid.uuid4()) + post = patch_post( + { + "data": { + "tenant": [{"id": tenant_id}], + "switchTenant": { + "accessToken": "ACCESS_TOKEN", + "expiresIn": 600, + "refreshToken": "REFRESH_TOKEN", + }, + } + } + ) + client = Client() + assert client._access_token is None + assert client._refresh_token is None + client.login_to_tenant(tenant_id=tenant_id) + assert client._access_token == "ACCESS_TOKEN" + assert client._refresh_token == "REFRESH_TOKEN" + + def test_login_uses_api_token(self, patch_post): + tenant_id = str(uuid.uuid4()) + post = patch_post( + { + "data": { + "tenant": [{"id": tenant_id}], + "switchTenant": { + "accessToken": "ACCESS_TOKEN", + "expiresIn": 600, + "refreshToken": "REFRESH_TOKEN", + }, + } + } + ) + client = Client(api_token="api") + client.login_to_tenant(tenant_id=tenant_id) + assert post.call_args[1]["headers"] == dict(Authorization="Bearer api") + + def test_login_uses_api_token_when_access_token_is_set(self, patch_post): + tenant_id = str(uuid.uuid4()) + post = patch_post( + { + "data": { + "tenant": [{"id": tenant_id}], + "switchTenant": { + "accessToken": "ACCESS_TOKEN", + "expiresIn": 600, + "refreshToken": "REFRESH_TOKEN", + }, + } + } + ) + client = Client(api_token="api") + client._access_token = "access" + client.login_to_tenant(tenant_id=tenant_id) + assert client.get_auth_token() == "ACCESS_TOKEN" + assert post.call_args[1]["headers"] == dict(Authorization="Bearer api") + + def test_graphql_uses_access_token_after_login(self, patch_post): + tenant_id = str(uuid.uuid4()) + post = patch_post( + { + "data": { + "tenant": [{"id": tenant_id}], + "switchTenant": { + "accessToken": "ACCESS_TOKEN", + "expiresIn": 600, + "refreshToken": "REFRESH_TOKEN", + }, + } + } + ) + client = Client(api_token="api") + client.graphql({}) + assert client.get_auth_token() == "api" + assert post.call_args[1]["headers"] == dict(Authorization="Bearer api") + + client.login_to_tenant(tenant_id=tenant_id) + client.graphql({}) + assert client.get_auth_token() == "ACCESS_TOKEN" + assert post.call_args[1]["headers"] == dict(Authorization="Bearer ACCESS_TOKEN") + + def test_login_to_tenant_writes_tenant_and_reloads_it_when_token_is_reloaded( + self, patch_post + ): + tenant_id = str(uuid.uuid4()) + post = patch_post( + { + "data": { + "tenant": [{"id": tenant_id}], + "switchTenant": { + "accessToken": "ACCESS_TOKEN", + "expiresIn": 600, + "refreshToken": "REFRESH_TOKEN", + }, + } + } + ) + + client = Client(api_token="abc") + assert client._active_tenant_id is None + client.login_to_tenant(tenant_id=tenant_id) + client.save_api_token() + assert client._active_tenant_id == tenant_id + + # new client loads the active tenant and token + assert Client()._active_tenant_id == tenant_id + assert Client()._api_token == "abc" + + def test_login_to_client_doesnt_reload_active_tenant_when_token_isnt_loaded( + self, patch_post + ): + tenant_id = str(uuid.uuid4()) + post = patch_post( + { + "data": { + "tenant": [{"id": tenant_id}], + "switchTenant": { + "accessToken": "ACCESS_TOKEN", + "expiresIn": 600, + "refreshToken": "REFRESH_TOKEN", + }, + } + } + ) + + client = Client(api_token="abc") + assert client._active_tenant_id is None + client.login_to_tenant(tenant_id=tenant_id) + assert client._active_tenant_id == tenant_id + + # new client doesn't load the active tenant because there's no api token loaded + assert Client()._active_tenant_id is None + + def test_logout_clears_access_token_and_tenant(self, patch_post): + tenant_id = str(uuid.uuid4()) + post = patch_post( + { + "data": { + "tenant": [{"id": tenant_id}], + "switchTenant": { + "accessToken": "ACCESS_TOKEN", + "expiresIn": 600, + "refreshToken": "REFRESH_TOKEN", + }, + } + } + ) + client = Client() + client.login_to_tenant(tenant_id=tenant_id) + + assert client._access_token is not None + assert client._refresh_token is not None + assert client._active_tenant_id is not None + + client.logout_from_tenant() + + assert client._access_token is None + assert client._refresh_token is None + assert client._active_tenant_id is None + + # new client doesn't load the active tenant + assert Client()._active_tenant_id is None + + def test_refresh_token_sets_attributes(self, patch_post): + patch_post( + { + "data": { + "refreshToken": { + "accessToken": "ACCESS_TOKEN", + "expiresIn": 600, + "refreshToken": "REFRESH_TOKEN", + } + } + } + ) + client = Client() + assert client._access_token is None + assert client._refresh_token is None + assert client._access_token_expires_at < pendulum.now() + client._refresh_access_token() + assert client._access_token is "ACCESS_TOKEN" + assert client._refresh_token is "REFRESH_TOKEN" + assert client._access_token_expires_at > pendulum.now().add(seconds=599) + + def test_refresh_token_passes_access_token_as_arg(self, patch_post): + post = patch_post( + { + "data": { + "refreshToken": { + "accessToken": "ACCESS_TOKEN", + "expiresIn": 600, + "refreshToken": "REFRESH_TOKEN", + } + } + } + ) + client = Client() + client._access_token = "access" + client._refresh_access_token() + variables = json.loads(post.call_args[1]["json"]["variables"]) + assert variables["input"]["accessToken"] == "access" + + def test_refresh_token_passes_refresh_token_as_header(self, patch_post): + post = patch_post( + { + "data": { + "refreshToken": { + "accessToken": "ACCESS_TOKEN", + "expiresIn": 600, + "refreshToken": "REFRESH_TOKEN", + } + } + } + ) + client = Client() + client._refresh_token = "refresh" + client._refresh_access_token() + assert post.call_args[1]["headers"] == dict(Authorization="Bearer refresh") + + def test_get_available_tenants(self, patch_post): + tenants = [ + {"id": "a", "name": "a-name", "slug": "a-slug"}, + {"id": "b", "name": "b-name", "slug": "b-slug"}, + {"id": "c", "name": "c-name", "slug": "c-slug"}, + ] + post = patch_post({"data": {"tenant": tenants}}) + client = Client() + gql_tenants = client.get_available_tenants() + assert gql_tenants == tenants + + def test_get_auth_token_returns_api_if_access_token_not_set(self): + client = Client(api_token="api") + assert client._access_token is None + assert client.get_auth_token() == "api" + + def test_get_auth_token_returns_access_token_if_set(self): + client = Client(api_token="api") + client._access_token = "access" + assert client.get_auth_token() == "access" + + def test_get_auth_token_refreshes_if_refresh_token_and_expiration_within_30_seconds( + self, monkeypatch + ): + refresh_token = MagicMock() + monkeypatch.setattr("prefect.Client._refresh_access_token", refresh_token) + client = Client(api_token="api") + client._access_token = "access" + client._refresh_token = "refresh" + client._access_token_expires_at = pendulum.now().add(seconds=29) + client.get_auth_token() + assert refresh_token.called + + def test_get_auth_token_refreshes_if_refresh_token_and_no_expiration( + self, monkeypatch + ): + refresh_token = MagicMock() + monkeypatch.setattr("prefect.Client._refresh_access_token", refresh_token) + client = Client(api_token="api") + client._access_token = "access" + client._refresh_token = "refresh" + client._access_token_expires_at = None + client.get_auth_token() + assert refresh_token.called + + def test_get_auth_token_doesnt_refreshe_if_refresh_token_and_future_expiration( + self, monkeypatch + ): + refresh_token = MagicMock() + monkeypatch.setattr("prefect.Client._refresh_access_token", refresh_token) + client = Client(api_token="api") + client._access_token = "access" + client._refresh_token = "refresh" + client._access_token_expires_at = pendulum.now().add(minutes=10) + assert client.get_auth_token() == "access" + refresh_token.assert_not_called() + + +class TestPassingHeadersAndTokens: + def test_headers_are_passed_to_get(self, monkeypatch): + get = MagicMock() + session = MagicMock() + session.return_value.get = get + monkeypatch.setattr("requests.Session", session) + with set_temporary_config( + {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} + ): + client = Client() + client.get("/foo/bar", headers={"x": "y", "Authorization": "z"}) + assert get.called + assert get.call_args[1]["headers"] == { + "x": "y", + "Authorization": "Bearer secret_token", + } + + def test_headers_are_passed_to_post(self, monkeypatch): + post = MagicMock() + session = MagicMock() + session.return_value.post = post + monkeypatch.setattr("requests.Session", session) + with set_temporary_config( + {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} + ): + client = Client() + client.post("/foo/bar", headers={"x": "y", "Authorization": "z"}) + assert post.called + assert post.call_args[1]["headers"] == { + "x": "y", + "Authorization": "Bearer secret_token", + } + + def test_headers_are_passed_to_graphql(self, monkeypatch): + post = MagicMock() + session = MagicMock() + session.return_value.post = post + monkeypatch.setattr("requests.Session", session) + with set_temporary_config( + {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"} + ): + client = Client() + client.graphql("query {}", headers={"x": "y", "Authorization": "z"}) + assert post.called + assert post.call_args[1]["headers"] == { + "x": "y", + "Authorization": "Bearer secret_token", + } + + def test_tokens_are_passed_to_get(self, monkeypatch): + get = MagicMock() + session = MagicMock() + session.return_value.get = get + monkeypatch.setattr("requests.Session", session) + with set_temporary_config({"cloud.graphql": "http://my-cloud.foo"}): + client = Client() + client.get("/foo/bar", token="secret_token") + assert get.called + assert get.call_args[1]["headers"] == {"Authorization": "Bearer secret_token"} + + def test_tokens_are_passed_to_post(self, monkeypatch): + post = MagicMock() + session = MagicMock() + session.return_value.post = post + monkeypatch.setattr("requests.Session", session) + with set_temporary_config({"cloud.graphql": "http://my-cloud.foo"}): + client = Client() + client.post("/foo/bar", token="secret_token") + assert post.called + assert post.call_args[1]["headers"] == {"Authorization": "Bearer secret_token"} + + def test_tokens_are_passed_to_graphql(self, monkeypatch): + post = MagicMock() + session = MagicMock() + session.return_value.post = post + monkeypatch.setattr("requests.Session", session) + with set_temporary_config({"cloud.graphql": "http://my-cloud.foo"}): + client = Client() + client.graphql("query {}", token="secret_token") + assert post.called + assert post.call_args[1]["headers"] == {"Authorization": "Bearer secret_token"} diff --git a/tests/client/test_secrets.py b/tests/client/test_secrets.py index 17dad89e9d19..b5080983d8ff 100644 --- a/tests/client/test_secrets.py +++ b/tests/client/test_secrets.py @@ -6,7 +6,7 @@ import prefect from prefect.client import Secret from prefect.utilities.configuration import set_temporary_config -from prefect.utilities.exceptions import AuthorizationError +from prefect.utilities.exceptions import AuthorizationError, ClientError ################################# ##### Secret Tests @@ -36,14 +36,11 @@ def test_secret_value_pulled_from_context(): def test_secret_value_depends_on_use_local_secrets(monkeypatch): secret = Secret(name="test") - monkeypatch.setattr( - "prefect.client.secrets.os.path.exists", MagicMock(return_value=False) - ) with set_temporary_config( {"cloud.use_local_secrets": False, "cloud.auth_token": None} ): with prefect.context(secrets=dict(test=42)): - with pytest.raises(AuthorizationError, match="Client.login"): + with pytest.raises(ClientError): secret.get() diff --git a/tests/conftest.py b/tests/conftest.py index 086cade03ef6..5e18972dfbf6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,26 @@ +import os +import tempfile import sys +from unittest.mock import MagicMock import pytest from distributed import Client import prefect from prefect.engine.executors import DaskExecutor, LocalExecutor, SynchronousExecutor -from prefect.utilities import debug +from prefect.utilities import debug, configuration + + +@pytest.fixture(autouse=True) +def prefect_home_dir(): + """ + Sets a temporary home directory + """ + with tempfile.TemporaryDirectory() as tmp: + tmp = os.path.join(tmp, ".prefect") + os.makedirs(tmp) + with configuration.set_temporary_config({"home_dir": tmp}): + yield tmp # ---------------- @@ -64,3 +79,23 @@ def executor(request, _switch): or with some subset of executors that you want to use. """ return _switch(request.param) + + +@pytest.fixture +def patch_post(monkeypatch): + """ + Patches `prefect.client.Client.post()` (and `graphql()`) to return the specified response. + + The return value of the fixture is a function that is called on the response to patch it. + + Typically, the response will contain up to two keys, `data` and `errors`. + """ + + def patch(response): + post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=response))) + session = MagicMock() + session.return_value.post = post + monkeypatch.setattr("requests.Session", session) + return post + + return patch From d227fe5e62b210750dac2be170cf61acb6ee0b1e Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 29 Aug 2019 15:30:48 -0400 Subject: [PATCH 2/4] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 318e091def97..d31c78aef120 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/ - Add `task_slug`, `flow_id`, and `flow_run_id` to context - [#1405](https://github.com/PrefectHQ/prefect/pull/1405) - Support persistent `scheduled_start_time` for scheduled flow runs when run locally with `flow.run()` - [#1418](https://github.com/PrefectHQ/prefect/pull/1418) - Add `task_args` to `Task.map` - [#1390](https://github.com/PrefectHQ/prefect/issues/1390) +- Add auth flows for `USER`-scoped Cloud API tokens - [#1423](https://github.com/PrefectHQ/prefect/pull/1423) ### Task Library From 05a044bdaf3d382604288f9a07fd21df68ae7ce2 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 29 Aug 2019 16:20:10 -0400 Subject: [PATCH 3/4] Update docs --- docs/cloud/agent/kubernetes.md | 4 +- docs/cloud/agent/local.md | 4 +- docs/cloud/cloud_concepts/api.md | 110 +++++++++++++++++++++++++++++ docs/cloud/cloud_concepts/auth.md | 32 --------- docs/cloud/cloud_concepts/debug.md | 2 +- 5 files changed, 115 insertions(+), 37 deletions(-) create mode 100644 docs/cloud/cloud_concepts/api.md delete mode 100644 docs/cloud/cloud_concepts/auth.md diff --git a/docs/cloud/agent/kubernetes.md b/docs/cloud/agent/kubernetes.md index 1c6616f890e2..45ae7bd408ac 100644 --- a/docs/cloud/agent/kubernetes.md +++ b/docs/cloud/agent/kubernetes.md @@ -45,8 +45,8 @@ The Kubernetes Agent can be started either through the Prefect CLI or by importi There are a few ways in which you can specify an `AGENT` API token: - command argument `prefect agent start kubernetes -t MY_TOKEN` -- environment variable `export PREFECT__CLOUD__AGENT__API_TOKEN=MY_TOKEN` -- token will be used from `prefect.config.cloud.api_token` if not provided from one of the two previous methods +- environment variable `export PREFECT__CLOUD__AGENT__AUTH_TOKEN=MY_TOKEN` +- token will be used from `prefect.config.cloud.auth_token` if not provided from one of the two previous methods ::: diff --git a/docs/cloud/agent/local.md b/docs/cloud/agent/local.md index c195e3decf64..de03b09ae723 100644 --- a/docs/cloud/agent/local.md +++ b/docs/cloud/agent/local.md @@ -32,8 +32,8 @@ The Local Agent can be started either through the Prefect CLI or by importing th There are a few ways in which you can specify an `AGENT` API token: - command argument `prefect agent start -t MY_TOKEN` -- environment variable `export PREFECT__CLOUD__AGENT__API_TOKEN=MY_TOKEN` -- token will be used from `prefect.config.cloud.api_token` if not provided from one of the two previous methods +- environment variable `export PREFECT__CLOUD__AGENT__AUTH_TOKEN=MY_TOKEN` +- token will be used from `prefect.config.cloud.auth_token` if not provided from one of the two previous methods ::: diff --git a/docs/cloud/cloud_concepts/api.md b/docs/cloud/cloud_concepts/api.md new file mode 100644 index 000000000000..6e0b5c2b8a2e --- /dev/null +++ b/docs/cloud/cloud_concepts/api.md @@ -0,0 +1,110 @@ +# Prefect Cloud API + +Prefect Cloud exposes a powerful GraphQL API for interacting with the platform. There are a variety of ways you can access the API. + +# Authentication + +In order to interact with Cloud from your local machine, you'll need to generate an API token. + +To generate an API token, use the Cloud UI or the following GraphQL call (from an already authenticated client!): + +```graphql +mutation { + createAPIToken(input: { name: "My API token", role: USER }) { + token + } +} +``` + +## API Token Scopes + +Prefect Cloud can generate API tokens with three different scopes. + +### `USER` + +`USER`-scoped API tokens function as personal access tokens. These tokens have very few permissions on their own, but can be used to authenticate with the Cloud API. Once authenticated, `USER` tokens can be used to generate short-lived JWT auth tokens for any tenant the user belongs to. These auth tokens inherit any permissions the user has in that tenant, allowing full API access. The Client manages the process of provisioning and refreshing these tokens. + +### `TENANT` + +`TENANT`-scoped API tokens are used for long-lived programmatic access to a specific Cloud tenant. Unlike `USER` tokens, which can adopt any tenant membership the user has, `TENANT` tokens are fixed to a specific membership in a specific tenant. They adopt whatever permissions the user has in the tenant. + +### `AGENT` + +`AGENT`-scoped API tokens are used for processes like the Prefect Agent, which require the ability to execute flows on behalf of a tenant. Unlike the other token types, `AGENT` tokens are not scoped to a particular user. Consequently, they can only be generated by tenant admins. + +# Python Client + +## About the Client + +Prefect Core includes a Python client for Prefect Cloud. The Python client was designed for both interactive and programmatic use, and includes convenience methods for transparently managing authentication when used with `USER`-scoped tokens. + +## Getting started + +For interactive use, the most common way to use the Cloud Client is to generate a `USER`-scoped token and provide it to the client. After doing so, users can save the token so it persists across all Python sessions: + +```python +import prefect +client = prefect.Client(api_token="YOUR_USER_TOKEN") +client.save_api_token() +``` + +Now, starting a client in another session will automatically reload the token: + +```python +client = prefect.Client() +assert client._api_token == "YOUR_USER_TOKEN" # True +``` + +Note that a token can be provided by environment variable (`PREFECT__CLOUD__AUTH_TOKEN`) or in your Prefect config (under `cloud.auth_token`). + +:::tip Using `USER` tokens +The steps shown here were designed to be used with `USER`-scoped tokens. They will work with `TENANT` scoped tokens as well, but unexpected errors could result. +::: + +Once provisioned with a `USER` token, the Cloud Client can query for available tenants and login to those tenants. In order to query for tenants, call: + +```python +client.get_available_tenants() +``` + +This will print the id, name, and slug of all the tenants the user can login to. + +```python +client.login_to_tenant(tenant_slug='a-tenant-slug') +# OR +client.login_to_tenant(tenant_id='A_TENANT_ID') +``` + +Both of these calls persist the active tenant in local storage, so you won't have to login again until you're ready to switch tenants. + +Once logged in, you can make any GraphQL query against the Cloud API: + +```python +client.graphql( + { + 'query': { + 'flow': { + 'id' + } + } + } +) +``` + +(Note that this illustrates how Prefect can parse Python structures to construct GraphQL query strings!) + +Finally, you can logout: + +```python +client.logout_from_tenant() +``` + +# GraphQL + +The Cloud API also supports direct GraphQL access. While you can still use `USER`-scoped tokens to access and log in to tenants, you will need to manage the short-lived auth and refresh tokens yourself. Therefore, we recommend using the Python client for `USER`-scoped access. + +For `TENANT`-scoped tokens, simply include the token as the authorization header of your GraphQL requests: + +```json +{ "authorization": "Bearer YOUR_TOKEN" } +``` diff --git a/docs/cloud/cloud_concepts/auth.md b/docs/cloud/cloud_concepts/auth.md deleted file mode 100644 index 80677ba3c191..000000000000 --- a/docs/cloud/cloud_concepts/auth.md +++ /dev/null @@ -1,32 +0,0 @@ -# Authentication - -In order to use the Cloud APIs from your local machine, you'll need to generate an API token. - -To generate an API token, use the Cloud UI or the following GraphQL call: - -```graphql -mutation { - createAPIToken(input: { name: "My API token", role: USER }) { - token - } -} -``` - -## Prefect Core client - -This token can either be added to your Prefect [user configuration file](../../guide/core_concepts/configuration.html): - -``` -[cloud] -auth_token = "" -``` - -or assigned to the environment variable `PREFECT__CLOUD__AUTH_TOKEN`. - -## GraphQL - -You can also use your API token to communicate directly with the Cloud GraphQL API, including the GraphQL Playground. Simply include the token in your HTTP headers like this: - -```json -{ "Authorization": "Bearer " } -``` diff --git a/docs/cloud/cloud_concepts/debug.md b/docs/cloud/cloud_concepts/debug.md index 6340b36e5533..17607308cf06 100644 --- a/docs/cloud/cloud_concepts/debug.md +++ b/docs/cloud/cloud_concepts/debug.md @@ -11,7 +11,7 @@ The most likely culprit when a flow is stuck in scheduled is agent misconfigurat 2. Check that the API token given to the agent is scoped to the same tenant as your flow ``` -$ export PREFECT__CLOUD__API_TOKEN=YOUR_TOKEN +$ export PREFECT__CLOUD__AUTH_TOKEN=YOUR_TOKEN $ prefect get flows # if you do not see your flow then there is a tenant mismatch ``` From 0f16077787b13dee1569835b2eb7f1b55bcdfc68 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 29 Aug 2019 16:56:22 -0400 Subject: [PATCH 4/4] Update api.md --- docs/cloud/cloud_concepts/api.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/cloud/cloud_concepts/api.md b/docs/cloud/cloud_concepts/api.md index 6e0b5c2b8a2e..50796aa6dff5 100644 --- a/docs/cloud/cloud_concepts/api.md +++ b/docs/cloud/cloud_concepts/api.md @@ -28,9 +28,9 @@ Prefect Cloud can generate API tokens with three different scopes. `TENANT`-scoped API tokens are used for long-lived programmatic access to a specific Cloud tenant. Unlike `USER` tokens, which can adopt any tenant membership the user has, `TENANT` tokens are fixed to a specific membership in a specific tenant. They adopt whatever permissions the user has in the tenant. -### `AGENT` +### `RUNNER` -`AGENT`-scoped API tokens are used for processes like the Prefect Agent, which require the ability to execute flows on behalf of a tenant. Unlike the other token types, `AGENT` tokens are not scoped to a particular user. Consequently, they can only be generated by tenant admins. +`RUNNER`-scoped API tokens are used for processes like the Prefect Agent, which require the ability to execute flows on behalf of a tenant. Unlike the other token types, `RUNNER` tokens are not scoped to a particular user. Consequently, they can only be generated by tenant admins. # Python Client