feat(datasets): allow multipart uploads for large datasets (#384)
This attempts to fall back to a multipart upload strategy with presigned
URLs in the event that a dataset is larger than 500MB
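In outline, the fallback follows the standard three-step S3 multipart flow: create an upload, PUT each chunk to a presigned URL, then complete the upload with the collected ETags. A minimal sketch of that flow, assuming a hypothetical request_presigned_urls(method, params) helper and a PART_SIZE constant that are not part of this change (the real calls live in gradient/commands/datasets.py below):

import requests

PART_SIZE = 15 * 1000 * 1000  # bytes; S3 requires >= 5MB for every part but the last


def multipart_upload(path, request_presigned_urls):
    # request_presigned_urls(method, params) is an assumed helper that asks the
    # backend to issue the named S3 call and return its result (for uploadPart,
    # a presigned URL to PUT the chunk to)
    upload = request_presigned_urls('createMultipartUpload', {})
    upload_id = upload['UploadId']

    parts = []
    with open(path, 'rb') as f, requests.Session() as session:
        part_number = 1
        while True:
            chunk = f.read(PART_SIZE)
            if not chunk:  # assumes a non-empty file
                break
            url = request_presigned_urls(
                'uploadPart',
                {'UploadId': upload_id, 'PartNumber': part_number})
            resp = session.put(url, data=chunk)
            # S3 identifies each part by the ETag it returns; they are all
            # echoed back when the upload is completed
            parts.append({'ETag': resp.headers['ETag'].strip('"'),
                          'PartNumber': part_number})
            part_number += 1

    request_presigned_urls(
        'completeMultipartUpload',
        {'UploadId': upload_id, 'MultipartUpload': {'Parts': parts}})

The commit below implements the same sequence against the Paperspace API's preSignedUrls endpoint, with a fixed 15MB part size.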
cwetherill-ps committed Apr 8, 2022
1 parent c0e0625 commit 53d4c87
Showing 1 changed file with 108 additions and 10 deletions.
gradient/commands/datasets.py: 118 changes (108 additions & 10 deletions)
@@ -12,6 +12,9 @@
import Queue as queue
from xml.etree import ElementTree
from urllib.parse import urlparse
from ..api_sdk.clients import http_client
from ..api_sdk.config import config
from ..cli_constants import CLI_PS_CLIENT_NAME

import halo
import requests
@@ -557,24 +560,114 @@ def update_status():

class PutDatasetFilesCommand(BaseDatasetFilesCommand):

    @classmethod
    def _put(cls, path, url, content_type):
    # @classmethod
    def _put(self, path, url, content_type, dataset_version_id=None, key=None):
        size = os.path.getsize(path)
        with requests.Session() as session:
            headers = {'Content-Type': content_type}

            try:
                if size > 0:
                if size <= 0:
                    headers.update({'Content-Size': '0'})
                    r = session.put(url, data='', headers=headers, timeout=5)
                # for files under half a GB
                elif size <= (10e8) / 2:
                    with open(path, 'rb') as f:
                        r = session.put(
                            url, data=f, headers=headers, timeout=5)
                # # for chonky files, use a multipart upload
                else:
                    headers.update({'Content-Size': '0'})
                    r = session.put(url, data='', headers=headers, timeout=5)

                cls.validate_s3_response(r)
                    # Chunks need to be at least 5MB or AWS throws an
                    # EntityTooSmall error; we'll arbitrarily choose a
                    # 15MB chunk size
                    #
                    # Note also that AWS limits the max number of chunks
                    # in a multipart upload to 10000, so this setting
                    # currently enforces a hard limit of 150GB per file.
                    #
                    # We can dynamically assign a larger part size if needed,
                    # but for the majority of use cases we should be fine
                    # as-is
                    part_minsize = int(15e6)
                    dataset_id, _, version = dataset_version_id.partition(":")
                    mpu_url = f'/datasets/{dataset_id}/versions/{version}/s3/preSignedUrls'
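                    # this endpoint takes a list of S3 calls to issue on our
                    # behalf (createMultipartUpload, uploadPart,
                    # completeMultipartUpload) and returns one result per call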

                    api_client = http_client.API(
                        api_url=config.CONFIG_HOST,
                        api_key=self.api_key,
                        ps_client_name=CLI_PS_CLIENT_NAME
                    )

                    mpu_create_res = api_client.post(
                        url=mpu_url,
                        json={
                            'datasetId': dataset_id,
                            'version': version,
                            'calls': [{
                                'method': 'createMultipartUpload',
                                'params': {'Key': key}
                            }]
                        }
                    )
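                    # the response body is a JSON list with one entry per
                    # requested call; the createMultipartUpload result, which
                    # carries the UploadId reused below, comes back under 'url'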
                    mpu_data = json.loads(mpu_create_res.text)[0]['url']

                    parts = []
                    with open(path, 'rb') as f:
                        # we +2 the number of parts since we're doing floor
                        # division, which will cut off any trailing part
                        # less than the part_minsize, AND we want to 1-index
                        # our range to match what AWS expects for part
                        # numbers
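                        # e.g. a 40MB file: 40_000_000 // 15_000_000 == 2, so
                        # range(1, 4) uploads parts 1-3 (15MB, 15MB, and 10MB)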
                        for part in range(1, (size // part_minsize) + 2):
                            presigned_url_res = api_client.post(
                                url=mpu_url,
                                json={
                                    'datasetId': dataset_id,
                                    'version': version,
                                    'calls': [{
                                        'method': 'uploadPart',
                                        'params': {
                                            'Key': key,
                                            'UploadId': mpu_data['UploadId'],
                                            'PartNumber': part
                                        }
                                    }]
                                }
                            )

                            presigned_url = json.loads(
                                presigned_url_res.text
                            )[0]['url']

                            chunk = f.read(part_minsize)
                            part_res = session.put(
                                presigned_url,
                                data=chunk,
                                timeout=5)
                            etag = part_res.headers['ETag'].replace('"', '')
                            parts.append({'ETag': etag, 'PartNumber': part})
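                            # S3 returns an ETag for every uploaded part; the
                            # ETag/PartNumber pairs collected here must be
                            # passed back when the upload is completed below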

                    r = api_client.post(
                        url=mpu_url,
                        json={
                            'datasetId': dataset_id,
                            'version': version,
                            'calls': [{
                                'method': 'completeMultipartUpload',
                                'params': {
                                    'Key': key,
                                    'UploadId': mpu_data['UploadId'],
                                    'MultipartUpload': {'Parts': parts}
                                }
                            }]
                        }
                    )

                self.validate_s3_response(r)
            except requests.exceptions.ConnectionError as e:
                return cls.report_connection_error(e)
                return self.report_connection_error(e)
            except Exception as e:
                return e

    @staticmethod
    def _list_files(source_path):
@@ -599,8 +692,13 @@ def _sign_and_put(self, dataset_version_id, pool, results, update_status):

        for pre_signed, result in zip(pre_signeds, results):
            update_status()
            pool.put(self._put, url=pre_signed.url,
                     path=result['path'], content_type=result['mimetype'])
            pool.put(
                self._put,
                url=pre_signed.url,
                path=result['path'],
                content_type=result['mimetype'],
                dataset_version_id=dataset_version_id,
                key=result['key'])

    def execute(self, dataset_version_id, source_paths, target_path):
        self.assert_supported(dataset_version_id)
