feat(datasets): allow multipart uploads for large datasets #384

Merged
118 changes: 108 additions & 10 deletions gradient/commands/datasets.py
@@ -12,6 +12,9 @@
import Queue as queue
from xml.etree import ElementTree
from urllib.parse import urlparse
from ..api_sdk.clients import http_client
from ..api_sdk.config import config
from ..cli_constants import CLI_PS_CLIENT_NAME

import halo
import requests
@@ -557,24 +560,114 @@ def update_status():

class PutDatasetFilesCommand(BaseDatasetFilesCommand):

@classmethod
def _put(cls, path, url, content_type):
# @classmethod
Contributor Author commented:

I don't think this breaks anything; I haven't been able to break it, at least.

def _put(self, path, url, content_type, dataset_version_id=None, key=None):
size = os.path.getsize(path)
with requests.Session() as session:
headers = {'Content-Type': content_type}

try:
if size > 0:
if size <= 0:
headers.update({'Content-Size': '0'})
r = session.put(url, data='', headers=headers, timeout=5)
# for files under half a GB
elif size <= (10e8) / 2:
with open(path, 'rb') as f:
r = session.put(
url, data=f, headers=headers, timeout=5)
# for chonky files, use a multipart upload
else:
headers.update({'Content-Size': '0'})
r = session.put(url, data='', headers=headers, timeout=5)

cls.validate_s3_response(r)
# Chunks need to be at least 5MB or AWS throws an
# EntityTooSmall error; we'll arbitrarily choose a
# 15MB chunksize
#
# Note also that AWS limits the max number of chunks
# in a multipart upload to 10000, so this setting
# currently enforces a hard limit of 150GB per file.
#
# We can dynamically assign a larger part size if needed,
# but for the majority of use cases we should be fine
# as-is
part_minsize = int(15e6)
dataset_id, _, version = dataset_version_id.partition(":")
mpu_url = f'/datasets/{dataset_id}/versions/{version}/s3/preSignedUrls'

api_client = http_client.API(
api_url=config.CONFIG_HOST,
api_key=self.api_key,
ps_client_name=CLI_PS_CLIENT_NAME
)

mpu_create_res = api_client.post(
url=mpu_url,
json={
'datasetId': dataset_id,
'version': version,
'calls': [{
'method': 'createMultipartUpload',
'params': {'Key': key}
}]
}
)
mpu_data = json.loads(mpu_create_res.text)[0]['url']

parts = []
with open(path, 'rb') as f:
# we +2 the number of parts since we're doing floor
# division, which will cut off any trailing part
# less than the part_minsize, AND we want to 1-index
# our range to match what AWS expects for part
# numbers
for part in range(1, (size // part_minsize) + 2):
Contributor Author commented:

this'll also add an extra empty part if the upload is exactly divisible by 500MB, which will probably cause an error from AWS due to it being too small. But also 🤷

@ghost commented on Aug 13, 2022:

This cost me my day. Use ceil...
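A minimal sketch of the ceiling-based count the reply suggests, next to the floor-division-plus-one count used by the loop above; `part_minsize` matches the 15MB value in the diff, and the example size is hypothetical:

```python
import math

part_minsize = int(15e6)  # 15MB, as in the diff

def part_count_floor(size):
    # mirrors range(1, (size // part_minsize) + 2): floor division plus one
    # extra slot, which yields an empty trailing part on exact multiples
    return (size // part_minsize) + 1

def part_count_ceil(size):
    # ceiling division: exact number of parts, no empty trailing part
    return math.ceil(size / part_minsize)

size = 40 * part_minsize          # 600MB, exactly divisible by the part size
print(part_count_floor(size))     # 41 -> the 41st part would be empty
print(part_count_ceil(size))      # 40
```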

presigned_url_res = api_client.post(
url=mpu_url,
json={
'datasetId': dataset_id,
'version': version,
'calls': [{
'method': 'uploadPart',
'params': {
'Key': key,
'UploadId': mpu_data['UploadId'],
'PartNumber': part
}
}]
}
)

presigned_url = json.loads(
presigned_url_res.text
)[0]['url']

chunk = f.read(part_minsize)
part_res = session.put(
presigned_url,
data=chunk,
timeout=5)
etag = part_res.headers['ETag'].replace('"', '')
parts.append({'ETag': etag, 'PartNumber': part})

r = api_client.post(
url=mpu_url,
json={
'datasetId': dataset_id,
'version': version,
'calls': [{
'method': 'completeMultipartUpload',
'params': {
'Key': key,
'UploadId': mpu_data['UploadId'],
'MultipartUpload': {'Parts': parts}
}
}]
}
)

self.validate_s3_response(r)
except requests.exceptions.ConnectionError as e:
return cls.report_connection_error(e)
return self.report_connection_error(e)
except Exception as e:
return e

@staticmethod
def _list_files(source_path):
@@ -599,8 +692,13 @@ def _sign_and_put(self, dataset_version_id, pool, results, update_status):

for pre_signed, result in zip(pre_signeds, results):
update_status()
pool.put(self._put, url=pre_signed.url,
path=result['path'], content_type=result['mimetype'])
pool.put(
Contributor Author commented:

Granted, this isn't really ideal. We're single-threading all parts of an upload in a single worker rather than distributing all N parts among all M workers in the pool. This will result in longer upload times, but that's better than the broken upload we have today. Soooo baby steps.

self._put,
url=pre_signed.url,
path=result['path'],
content_type=result['mimetype'],
dataset_version_id=dataset_version_id,
key=result['key'])
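For reference, a rough sketch of the alternative described in the comment above: fanning the parts of a single file out across several workers instead of giving the whole file to one worker. It uses `concurrent.futures` rather than the project's worker pool, and the `presigned_urls` mapping (part number to presigned URL) is a hypothetical stand-in for the per-part presign calls shown in the diff:

```python
from concurrent.futures import ThreadPoolExecutor

import requests


def upload_parts_concurrently(path, presigned_urls, part_minsize=int(15e6), max_workers=4):
    """Upload each chunk of `path` through its own presigned URL, in parallel."""

    def put_part(part_number):
        # each worker reads only its own 15MB slice of the file
        with open(path, 'rb') as f:
            f.seek((part_number - 1) * part_minsize)
            chunk = f.read(part_minsize)
        r = requests.put(presigned_urls[part_number], data=chunk, timeout=30)
        r.raise_for_status()
        return {'ETag': r.headers['ETag'].strip('"'), 'PartNumber': part_number}

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        parts = list(pool.map(put_part, sorted(presigned_urls)))
    # `parts` is ordered by part number and ready for completeMultipartUpload
    return parts
```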

def execute(self, dataset_version_id, source_paths, target_path):
self.assert_supported(dataset_version_id)
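The three `calls` the new code proxies through the datasets API mirror the standard S3 multipart sequence. As a point of comparison (not part of this PR), a sketch of the same flow against S3 directly with boto3; the bucket and key are placeholders:

```python
import math
import os

import boto3


def multipart_upload(path, bucket, key, part_size=int(15e6)):
    """Create the upload, send each part, then complete it."""
    s3 = boto3.client('s3')
    upload = s3.create_multipart_upload(Bucket=bucket, Key=key)  # createMultipartUpload
    parts = []
    part_count = math.ceil(os.path.getsize(path) / part_size)
    with open(path, 'rb') as f:
        for part_number in range(1, part_count + 1):
            chunk = f.read(part_size)
            res = s3.upload_part(                                # uploadPart
                Bucket=bucket,
                Key=key,
                PartNumber=part_number,
                UploadId=upload['UploadId'],
                Body=chunk,
            )
            parts.append({'ETag': res['ETag'], 'PartNumber': part_number})
    s3.complete_multipart_upload(                                # completeMultipartUpload
        Bucket=bucket,
        Key=key,
        UploadId=upload['UploadId'],
        MultipartUpload={'Parts': parts},
    )
```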