Skip to content

Commit

Permalink
feat(client): support uploading label on cloud callback
Browse files Browse the repository at this point in the history
PR Closed: #1014
  • Loading branch information
zhen.chen committed Sep 24, 2021
1 parent 2838f02 commit f0fe282
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 19 deletions.
30 changes: 13 additions & 17 deletions tensorbay/client/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,20 +259,6 @@ def _synchronize_upload_info(

self._client.open_api_do("PUT", "multi/callback", self._dataset_id, json=put_data)

def _import_cloud_file(
self,
cloud_path: str,
remote_path: str,
) -> None:

put_data: Dict[str, Any] = {
"segmentName": self.name,
"objects": [{"cloudPath": cloud_path, "remotePath": remote_path}],
"deleteSource": False,
}
put_data.update(self._status.get_status_info())
self._client.open_api_do("PUT", "multi/cloud-callback", self._dataset_id, json=put_data)

def _upload_label(self, data: Union[AuthData, Data]) -> None:
label = data.label.dumps()
if not label:
Expand Down Expand Up @@ -430,9 +416,19 @@ def import_auth_data(self, data: AuthData) -> None:
"""
self._status.check_authority_for_draft()

self._import_cloud_file(data.path, data.target_remote_path)
self._upload_label(data)
put_data: Dict[str, Any] = {
"segmentName": self.name,
"objects": [
{
"cloudPath": data.path,
"remotePath": data.target_remote_path,
"label": data.label.dumps(),
}
],
"deleteSource": False,
}
put_data.update(self._status.get_status_info())
self._client.open_api_do("PUT", "multi/cloud-callback", self._dataset_id, json=put_data)

def copy_data(
self,
Expand Down
7 changes: 5 additions & 2 deletions tests/test_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tensorbay import GAS
from tensorbay.dataset import Dataset, Segment
from tensorbay.exception import ResourceNotExistError
from tensorbay.label import Classification

from .utility import get_dataset_name

Expand Down Expand Up @@ -43,15 +44,17 @@ def test_import_cloud_files(self, accesskey, url, config_name):
dataset = Dataset(name=dataset_name)
segment = dataset.create_segment("Segment1")
for data in auth_data:
data.label.classification = Classification("cat", attributes={"color": "red"})
segment.append(data)

dataset_client = gas_client.upload_dataset(dataset, jobs=5)
dataset_client.commit("import data")
# dataset_client.commit("import data")

segment1 = Segment("Segment1", client=dataset_client)
assert len(segment1) == len(segment)
assert segment1[0].path == segment[0].path.split("/")[-1]
assert not segment1[0].label
assert segment1[0].label.classification.category == "cat"
assert segment1[0].label.classification.attributes["color"] == "red"

assert len(auth_data) == len(segment)

Expand Down

0 comments on commit f0fe282

Please sign in to comment.