Skip to content

Commit

Permalink
Merge pull request #1690 from PrefectHQ/client-header
Browse files Browse the repository at this point in the history
Add version header to all client requests
  • Loading branch information
cicdw committed Nov 1, 2019
2 parents be58009 + 64b4ac8 commit b6d7668
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

- Add a `save`/`load` interface to Flows - [#1685](https://github.com/PrefectHQ/prefect/pull/1685)
- Add option to specify `aws_session_token` for the `FargateTaskEnvironment` - [#1688](https://github.com/PrefectHQ/prefect/pull/1688)
- Add an informative version header to all Cloud client requests - [#1690](https://github.com/PrefectHQ/prefect/pull/1690)

### Task Library

Expand Down
1 change: 1 addition & 0 deletions src/prefect/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def _request(
headers = headers or {}
if token:
headers["Authorization"] = "Bearer {}".format(token)
headers["X-PREFECT-CORE-VERSION"] = str(prefect.__version__)

session = requests.Session()
retries = Retry(
Expand Down
30 changes: 30 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,36 @@ def test_client_posts_to_api_server(patch_post):
assert post.call_args[0][0] == "http://my-cloud.foo/foo/bar"


def test_version_header(monkeypatch):
get = MagicMock()
session = MagicMock()
session.return_value.get = get
monkeypatch.setattr("requests.Session", session)
with set_temporary_config(
{"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
):
client = Client()
client.get("/foo/bar")
assert get.call_args[1]["headers"]["X-PREFECT-CORE-VERSION"] == str(
prefect.__version__
)


def test_version_header_cant_be_overridden(monkeypatch):
get = MagicMock()
session = MagicMock()
session.return_value.get = get
monkeypatch.setattr("requests.Session", session)
with set_temporary_config(
{"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
):
client = Client()
client.get("/foo/bar", headers={"X-PREFECT-CORE-VERSION": "-1",})
assert get.call_args[1]["headers"]["X-PREFECT-CORE-VERSION"] == str(
prefect.__version__
)


def test_client_posts_graphql_to_api_server(patch_post):
post = patch_post(dict(data=dict(success=True)))

Expand Down
43 changes: 35 additions & 8 deletions tests/client/test_client_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ def test_login_uses_api_token(self, patch_post):
)
client = Client(api_token="api")
client.login_to_tenant(tenant_id=tenant_id)
assert post.call_args[1]["headers"] == dict(Authorization="Bearer api")
assert post.call_args[1]["headers"] == {
"Authorization": "Bearer api",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

def test_login_uses_api_token_when_access_token_is_set(self, patch_post):
tenant_id = str(uuid.uuid4())
Expand All @@ -200,7 +203,10 @@ def test_login_uses_api_token_when_access_token_is_set(self, patch_post):
client._access_token = "access"
client.login_to_tenant(tenant_id=tenant_id)
assert client.get_auth_token() == "ACCESS_TOKEN"
assert post.call_args[1]["headers"] == dict(Authorization="Bearer api")
assert post.call_args[1]["headers"] == {
"Authorization": "Bearer api",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

def test_graphql_uses_access_token_after_login(self, patch_post):
tenant_id = str(uuid.uuid4())
Expand All @@ -219,12 +225,18 @@ def test_graphql_uses_access_token_after_login(self, patch_post):
client = Client(api_token="api")
client.graphql({})
assert client.get_auth_token() == "api"
assert post.call_args[1]["headers"] == dict(Authorization="Bearer api")
assert post.call_args[1]["headers"] == {
"Authorization": "Bearer api",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

client.login_to_tenant(tenant_id=tenant_id)
client.graphql({})
assert client.get_auth_token() == "ACCESS_TOKEN"
assert post.call_args[1]["headers"] == dict(Authorization="Bearer ACCESS_TOKEN")
assert post.call_args[1]["headers"] == {
"Authorization": "Bearer ACCESS_TOKEN",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

def test_login_to_tenant_writes_tenant_and_reloads_it_when_token_is_reloaded(
self, patch_post
Expand Down Expand Up @@ -364,7 +376,10 @@ def test_refresh_token_passes_refresh_token_as_header(self, patch_post):
client = Client()
client._refresh_token = "refresh"
client._refresh_access_token()
assert post.call_args[1]["headers"] == dict(Authorization="Bearer refresh")
assert post.call_args[1]["headers"] == {
"Authorization": "Bearer refresh",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

def test_get_available_tenants(self, patch_post):
tenants = [
Expand Down Expand Up @@ -466,6 +481,7 @@ def test_headers_are_passed_to_get(self, monkeypatch):
assert get.call_args[1]["headers"] == {
"x": "y",
"Authorization": "Bearer secret_token",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

def test_headers_are_passed_to_post(self, monkeypatch):
Expand All @@ -482,6 +498,7 @@ def test_headers_are_passed_to_post(self, monkeypatch):
assert post.call_args[1]["headers"] == {
"x": "y",
"Authorization": "Bearer secret_token",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

def test_headers_are_passed_to_graphql(self, monkeypatch):
Expand All @@ -498,6 +515,7 @@ def test_headers_are_passed_to_graphql(self, monkeypatch):
assert post.call_args[1]["headers"] == {
"x": "y",
"Authorization": "Bearer secret_token",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

def test_tokens_are_passed_to_get(self, monkeypatch):
Expand All @@ -509,7 +527,10 @@ def test_tokens_are_passed_to_get(self, monkeypatch):
client = Client()
client.get("/foo/bar", token="secret_token")
assert get.called
assert get.call_args[1]["headers"] == {"Authorization": "Bearer secret_token"}
assert get.call_args[1]["headers"] == {
"Authorization": "Bearer secret_token",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

def test_tokens_are_passed_to_post(self, monkeypatch):
post = MagicMock()
Expand All @@ -520,7 +541,10 @@ def test_tokens_are_passed_to_post(self, monkeypatch):
client = Client()
client.post("/foo/bar", token="secret_token")
assert post.called
assert post.call_args[1]["headers"] == {"Authorization": "Bearer secret_token"}
assert post.call_args[1]["headers"] == {
"Authorization": "Bearer secret_token",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

def test_tokens_are_passed_to_graphql(self, monkeypatch):
post = MagicMock()
Expand All @@ -531,4 +555,7 @@ def test_tokens_are_passed_to_graphql(self, monkeypatch):
client = Client()
client.graphql("query {}", token="secret_token")
assert post.called
assert post.call_args[1]["headers"] == {"Authorization": "Bearer secret_token"}
assert post.call_args[1]["headers"] == {
"Authorization": "Bearer secret_token",
"X-PREFECT-CORE-VERSION": str(prefect.__version__),
}

0 comments on commit b6d7668

Please sign in to comment.