Skip to content

Commit

Permalink
Fixes issues with credentials getting expired (#784)
Browse files Browse the repository at this point in the history
* dynamically update creds

* credential refresh now works

* linting fix

* moved value to a constant
  • Loading branch information
AbhinavTuli committed Apr 22, 2021
1 parent 4168264 commit 0a435a8
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 14 deletions.
4 changes: 2 additions & 2 deletions hub/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def delete(self):
exist_meta = fs.exists(posixpath.join(path, defaults.META_FILE))
if exist_meta:
fs.rm(path, recursive=True)
if self.username is not None:
if self.username:
HubControlClient().delete_dataset_entry(
self.username, self.dataset_name
)
Expand Down Expand Up @@ -863,7 +863,7 @@ def close(self):
self._update_dataset_state()

def _update_dataset_state(self):
if self.username is not None:
if self.username:
HubControlClient().update_dataset_state(
self.username, self.dataset_name, "UPLOADED"
)
Expand Down
7 changes: 3 additions & 4 deletions hub/client/hub_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from hub.client.base import HubHttpClient
from hub.client.auth import AuthClient
from pathlib import Path

from hub import defaults
from hub.exceptions import NotFoundException
from hub.log import logger
import traceback
Expand Down Expand Up @@ -49,6 +49,7 @@ def get_credentials(self):
"GET",
config.GET_CREDENTIALS_SUFFIX,
endpoint=config.HUB_REST_ENDPOINT,
params={"duration": defaults.CRED_EXPIRATION},
).json()

details = {
Expand All @@ -66,15 +67,13 @@ def get_credentials(self):
return details

def get_config(self, reset=False):

if not os.path.isfile(config.STORE_CONFIG_PATH) or self.auth_header is None:
self.get_credentials()

with open(config.STORE_CONFIG_PATH) as file:
details = file.readlines()
details = json.loads("".join(details))

if float(details["expiration"]) < time.time() - 36000 or reset:
if float(details["expiration"]) < time.time() or reset:
details = self.get_credentials()
return details

Expand Down
1 change: 1 addition & 0 deletions hub/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
AZURE_HOST_SUFFIX = "blob.core.windows.net"
META_FILE = "meta.json"
VERSION_INFO = "version.pkl"
CRED_EXPIRATION = 36000 # in seconds
4 changes: 3 additions & 1 deletion hub/store/s3_file_system_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@


class S3FileSystemReplacement(S3FileSystem):
def __init__(self, *args, **kwargs):
def __init__(self, *args, expiration=None, **kwargs):
super().__init__(*args, **kwargs)
self._args = args
self._kwargs = kwargs
self.expiration = expiration

def get_mapper(self, root: str, check=False, create=False):
root = "s3://" + root
Expand All @@ -29,4 +30,5 @@ def get_mapper(self, root: str, check=False, create=False):
aws_session_token=self._kwargs.get("token"),
aws_region=aws_region,
endpoint_url=endpoint_url,
expiration=self.expiration,
)
46 changes: 39 additions & 7 deletions hub/store/s3_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from hub.exceptions import S3Exception
from hub.log import logger
from hub.client.hub_control import HubControlClient
import time


class S3Storage(MutableMapping):
Expand All @@ -28,6 +30,7 @@ def __init__(
parallel=25,
endpoint_url=None,
aws_region=None,
expiration=None,
):
self.s3fs = s3fs
self.root = {}
Expand All @@ -36,6 +39,7 @@ def __init__(
self.parallel = parallel
self.aws_region = aws_region
self.endpoint_url = endpoint_url
self.expiration = expiration
self.bucket = url.split("/")[2]
self.path = "/".join(url.split("/")[3:])
if self.bucket == "s3:":
Expand All @@ -45,7 +49,7 @@ def __init__(
self.bucketpath = posixpath.join(self.bucket, self.path)
self.protocol = "object"

client_config = botocore.config.Config(
self.client_config = botocore.config.Config(
max_pool_connections=parallel,
)

Expand All @@ -54,22 +58,46 @@ def __init__(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
config=client_config,
endpoint_url=endpoint_url,
region_name=aws_region,
config=self.client_config,
endpoint_url=self.endpoint_url,
region_name=self.aws_region,
)

self.resource = boto3.resource(
"s3",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
config=client_config,
endpoint_url=endpoint_url,
region_name=aws_region,
config=self.client_config,
endpoint_url=self.endpoint_url,
region_name=self.aws_region,
)

def check_update_creds(self):
if self.expiration and float(self.expiration) < time.time():
details = HubControlClient().get_credentials()
self.expiration = details["expiration"]
self.client = boto3.client(
"s3",
aws_access_key_id=details["access_key"],
aws_secret_access_key=details["secret_key"],
aws_session_token=details["session_token"],
config=self.client_config,
endpoint_url=self.endpoint_url,
region_name=self.aws_region,
)
self.resource = boto3.resource(
"s3",
aws_access_key_id=details["access_key"],
aws_secret_access_key=details["secret_key"],
aws_session_token=details["session_token"],
config=self.client_config,
endpoint_url=self.endpoint_url,
region_name=self.aws_region,
)

def __setitem__(self, path, content):
self.check_update_creds()
try:
path = posixpath.join(self.path, path)
content = bytearray(memoryview(content))
Expand All @@ -86,6 +114,7 @@ def __setitem__(self, path, content):
raise S3Exception(err)

def __getitem__(self, path):
self.check_update_creds()
try:
path = posixpath.join(self.path, path)
resp = self.client.get_object(
Expand All @@ -104,6 +133,7 @@ def __getitem__(self, path):
raise S3Exception(err)

def __delitem__(self, path):
self.check_update_creds()
try:
path = posixpath.join(self.bucketpath, path)
self.s3fs.rm(path, recursive=True)
Expand All @@ -112,8 +142,10 @@ def __delitem__(self, path):
raise S3Exception(err)

def __len__(self):
self.check_update_creds()
return len(self.s3fs.ls(self.bucketpath, detail=False, refresh=True))

def __iter__(self):
self.check_update_creds()
items = self.s3fs.ls(self.bucketpath, detail=False, refresh=True)
yield from [item[len(self.bucketpath) + 1 :] for item in items]
1 change: 1 addition & 0 deletions hub/store/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def get_fs_and_path(
# TODO: check if url is username/dataset:version
url, creds = _connect(url, public=public)
fs = S3FileSystemReplacement(
expiration=creds["expiration"],
key=creds["access_key"],
secret=creds["secret_key"],
token=creds["session_token"],
Expand Down

0 comments on commit 0a435a8

Please sign in to comment.