Skip to content

Commit

Permalink
Refresh session token when the jwt expires. (#446)
Browse files Browse the repository at this point in the history
  • Loading branch information
DailyDreaming committed Nov 6, 2019
1 parent 5c61f29 commit aa4cc95
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 4 deletions.
17 changes: 13 additions & 4 deletions hca/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def _request(self, req_args, url=None, stream=False, headers=None):
else:
session = self.client.get_session()

# TODO: (akislyuk) if using service account credentials, use manual refresh here
json_input = body if self.body_props else None
headers = headers or {}
headers.update({k: v for k, v in req_args.items() if self.parameters.get(k, {}).get('in') == 'header'})
Expand Down Expand Up @@ -447,14 +446,24 @@ def _get_jwt_from_service_account_credentials(self):
algorithm='RS256').decode()
return signed_jwt, exp

def expired_token(self):
"""Return True if we have an active session containing an expired (or nearly expired) token."""
ten_second_buffer = 10
if self._authenticated_session:
token_expiration = self._authenticated_session.token['expires_at']
if token_expiration:
if token_expiration <= time.time() + ten_second_buffer:
return True
return False

def get_authenticated_session(self):
if self._authenticated_session is None:
if self._authenticated_session is None or self.expired_token():
oauth2_client_data = self.application_secrets["installed"]
if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
token, expires_at = self._get_jwt_from_service_account_credentials()
# TODO: (akislyuk) figure out the right strategy for persisting the service account oauth2 token
self._authenticated_session = OAuth2Session(client_id=oauth2_client_data["client_id"],
token=dict(access_token=token),
token=dict(access_token=token,
expires_at=expires_at),
**self._session_kwargs)
else:
if "oauth2_token" not in self.config:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@
from unittest.mock import mock_open


import time
from hca.dss import DSSClient


class TestTokenDSSClient(DSSClient):
"""Mocked client that always expires request tokens within 1 second."""
token_expiration = 1

def __init__(self, *args, **kwargs):
super(TestTokenDSSClient, self).__init__(*args, **kwargs)


class TestSwaggerClient(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -116,6 +128,55 @@ def test_get_swagger_spec_local_config(self):
self.assertFalse(mock_get.called)
self.assertFalse(mock_atomic_write.called)

def test_swagger_client_refresh(self):
"""Instantiates a modified DSS client that only makes 1 second expiration tokens, forcing it to refresh."""
dss = TestTokenDSSClient(swagger_url='https://dss.dev.data.humancellatlas.org/v1/swagger.json')
assert dss._authenticated_session is None

# we use collections to test because it's an authenticated endpoint
r = dss.get_collections()
assert 'collections' in r
token_one = dss._authenticated_session.token['access_token']
expires_at = dss._authenticated_session.token['expires_at'] - time.time()
assert expires_at < 1

time.sleep(2) # wait out the 1 second expiration token

r = dss.get_collections()
assert 'collections' in r
token_two = dss._authenticated_session.token['access_token']
expires_at = dss._authenticated_session.token['expires_at'] - time.time()
assert expires_at < 1

assert token_one != token_two # make sure it requested with two different tokens

def test_swagger_client_no_refresh(self):
"""
Instantiates the normal DSSClient with a 3600 second expiration token so that we can check
that it successfully uses the same token for both requests.
"""
dss = DSSClient(swagger_url='https://dss.dev.data.humancellatlas.org/v1/swagger.json')
assert dss._authenticated_session is None

# we use collections to test because it's an authenticated endpoint
r = dss.get_collections()
assert 'collections' in r
token_one = dss._authenticated_session.token['access_token']
expires_at = dss._authenticated_session.token['expires_at'] - time.time()
assert expires_at < 3600
assert expires_at > 3590

time.sleep(2)

r = dss.get_collections()
assert 'collections' in r
token_two = dss._authenticated_session.token['access_token']
expires_at = dss._authenticated_session.token['expires_at'] - time.time()
assert expires_at < 3600
assert expires_at > 3590

assert token_one == token_two # we used one long-lived token for both requests


if __name__ == "__main__":
unittest.main()

0 comments on commit aa4cc95

Please sign in to comment.