Skip to content

Commit

Permalink
feat(client): rebuild the data uploading procedure
Browse files Browse the repository at this point in the history
  • Loading branch information
rexzheng324-c committed Jun 7, 2021
1 parent bd92d9f commit 43e3f5a
Showing 1 changed file with 37 additions and 31 deletions.
68 changes: 37 additions & 31 deletions tensorbay/client/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
import os
import time
from copy import deepcopy
from hashlib import sha1
from itertools import islice
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union

import filetype
from requests_toolbelt import MultipartEncoder
Expand All @@ -44,12 +45,6 @@
from .dataset import DatasetClient, FusionDatasetClient


_SERVER_VERSION_MATCH: Dict[str, str] = {
"AmazonS3": "x-amz-version-id",
"AliyunOSS": "x-oss-version-id",
}


class SegmentClientBase: # pylint: disable=too-many-instance-attributes
"""This class defines the basic concept of :class:`SegmentClient`.
Expand Down Expand Up @@ -79,6 +74,14 @@ def __init__(
self._client = dataset_client._client # pylint: disable=protected-access
self._permission: Dict[str, Any] = {"expireAt": 0}

@staticmethod
def _calculate_file_sha1(local_path: str) -> str:
sha1_obj = sha1()
with open(local_path, "rb") as fp:
sha1_obj.update(fp.read())

return sha1_obj.hexdigest()

def _get_url(self, remote_path: str) -> str:
"""Get URL of a specific remote path.
Expand Down Expand Up @@ -136,54 +139,50 @@ def _post_multipart_formdata(
local_path: str,
remote_path: str,
data: Dict[str, Any],
) -> Tuple[str, str]:
) -> None:
with open(local_path, "rb") as fp:
file_type = filetype.guess_mime(local_path)
if "x-amz-date" in data:
data["Content-Type"] = file_type
data["file"] = (remote_path, fp, file_type)
multipart = MultipartEncoder(data)
response_headers = self._client.do(

self._client.do(
"POST", url, data=multipart, headers={"Content-Type": multipart.content_type}
).headers
version = _SERVER_VERSION_MATCH[response_headers["Server"]]
return response_headers[version], response_headers["ETag"].strip('"')
)

def _put_binary_file_to_azure(
self,
url: str,
local_path: str,
data: Dict[str, Any],
) -> Tuple[str, str]:
) -> None:
with open(local_path, "rb") as fp:
file_type = filetype.guess_mime(local_path)
request_headers = {
"x-ms-blob-content-type": file_type,
"x-ms-blob-type": data["x-ms-blob-type"],
}
response_headers = self._client.do("PUT", url, data=fp, headers=request_headers).headers
return response_headers["x-ms-version-id"], response_headers["ETag"].strip('"')
self._client.do("PUT", url, data=fp, headers=request_headers)

def _synchronize_upload_info( # pylint: disable=too-many-arguments
self,
key: str,
version_id: str,
etag: str,
remote_path: str,
checksum: str,
frame_info: Optional[Dict[str, Any]] = None,
skip_uploaded_files: bool = False,
) -> None:
put_data: Dict[str, Any] = {
"key": key,
"versionId": version_id,
"etag": etag,
"segmentName": self.name,
"objects": [{"checksum": checksum, "remotePath": remote_path}],
}
put_data.update(self._status.get_status_info())

if frame_info:
put_data.update(frame_info)
put_data["objects"][0].update(frame_info)

try:
self._client.open_api_do("PUT", "callback", self._dataset_id, json=put_data)
self._client.open_api_do("PUT", "multi/callback", self._dataset_id, json=put_data)
except ResponseSystemError:
if not skip_uploaded_files:
raise
Expand Down Expand Up @@ -306,7 +305,11 @@ def upload_file(self, local_path: str, target_remote_path: str = "") -> None:

permission = self._get_upload_permission()
post_data = permission["result"]
post_data["key"] = permission["extra"]["objectPrefix"] + target_remote_path
del post_data["multipleUploadLimit"]

checksum = self._calculate_file_sha1(local_path)

post_data["key"] = checksum

backend_type = permission["extra"]["backendType"]
if backend_type == "azure":
Expand All @@ -315,16 +318,16 @@ def upload_file(self, local_path: str, target_remote_path: str = "") -> None:
f'{target_remote_path}?{permission["result"]["token"]}'
)

version_id, etag = self._put_binary_file_to_azure(url, local_path, post_data)
self._put_binary_file_to_azure(url, local_path, post_data)
else:
version_id, etag = self._post_multipart_formdata(
self._post_multipart_formdata(
permission["extra"]["host"],
local_path,
target_remote_path,
post_data,
)

self._synchronize_upload_info(post_data["key"], version_id, etag)
self._synchronize_upload_info(target_remote_path, checksum)

def upload_label(self, data: Data) -> None:
"""Upload label with Data object to the draft.
Expand Down Expand Up @@ -484,7 +487,10 @@ def upload_frame( # pylint: disable=too-many-locals

permission = self._get_upload_permission()
post_data = permission["result"]
post_data["key"] = permission["extra"]["objectPrefix"] + target_remote_path

checksum = self._calculate_file_sha1(data.path)

post_data["key"] = checksum

backend_type = permission["extra"]["backendType"]
if backend_type == "azure":
Expand All @@ -493,9 +499,9 @@ def upload_frame( # pylint: disable=too-many-locals
f'{target_remote_path}?{permission["result"]["token"]}'
)

version_id, etag = self._put_binary_file_to_azure(url, data.path, post_data)
self._put_binary_file_to_azure(url, data.path, post_data)
else:
version_id, etag = self._post_multipart_formdata(
self._post_multipart_formdata(
permission["extra"]["host"],
data.path,
target_remote_path,
Expand All @@ -511,7 +517,7 @@ def upload_frame( # pylint: disable=too-many-locals
frame_info["timestamp"] = data.timestamp

self._synchronize_upload_info(
post_data["key"], version_id, etag, frame_info, skip_uploaded_files
target_remote_path, checksum, frame_info, skip_uploaded_files
)

self._upload_label(data)
Expand Down

0 comments on commit 43e3f5a

Please sign in to comment.