Merge 3cfa3e1 into 03e2373
graczhual committed Nov 15, 2021
2 parents 03e2373 + 3cfa3e1 commit 76c599e
Showing 2 changed files with 212 additions and 33 deletions.
78 changes: 77 additions & 1 deletion tensorbay/client/segment.py
@@ -37,7 +37,7 @@
from tensorbay.client.requests import config
from tensorbay.client.status import Status
from tensorbay.dataset import AuthData, Data, Frame, RemoteData
from tensorbay.exception import FrameError, InvalidParamsError, ResourceNotExistError, ResponseError
from tensorbay.label import Label
from tensorbay.sensor.sensor import Sensor, Sensors
from tensorbay.utility import URL, FileMixin, chunked, locked
@@ -123,6 +123,24 @@ def _list_urls(self, offset: int = 0, limit: int = 128) -> Dict[str, Any]:
response = self._client.open_api_do("GET", "data/urls", self._dataset_id, params=params)
return response.json() # type: ignore[no-any-return]

def _get_data_details(self, remote_path: str) -> Dict[str, Any]:
params: Dict[str, Any] = {
"segmentName": self._name,
"remotePath": remote_path,
}
params.update(self._status.get_status_info())

if config.is_internal:
params["isInternal"] = True

response = self._client.open_api_do("GET", "data/details", self._dataset_id, params=params)
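# "dataDetails" comes back as a list; an empty list means no data matches remote_path in this segment.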
try:
data_details = response.json()["dataDetails"][0]
except IndexError as error:
raise ResourceNotExistError(resource="data", identification=remote_path) from error

return data_details # type: ignore[no-any-return]

def _list_data_details(self, offset: int = 0, limit: int = 128) -> Dict[str, Any]:
params: Dict[str, Any] = {
"segmentName": self._name,
@@ -137,6 +155,27 @@ def _list_data_details(self, offset: int = 0, limit: int = 128) -> Dict[str, Any
response = self._client.open_api_do("GET", "data/details", self._dataset_id, params=params)
return response.json() # type: ignore[no-any-return]

def _get_mask_url(self, mask_type: str, remote_path: str) -> str:
params: Dict[str, Any] = {
"segmentName": self._name,
"maskType": mask_type,
"remotePath": remote_path,
}
params.update(self._status.get_status_info())

if config.is_internal:
params["isInternal"] = True

response = self._client.open_api_do("GET", "masks/urls", self._dataset_id, params=params)
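# An empty "urls" list means the requested mask type has no file for remote_path.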
try:
mask_url = response.json()["urls"][0]["url"]
except IndexError as error:
raise ResourceNotExistError(
resource=f"{mask_type} of data", identification=remote_path
) from error

return mask_url # type: ignore[no-any-return]

def _list_mask_urls(self, mask_type: str, offset: int = 0, limit: int = 128) -> Dict[str, Any]:
params: Dict[str, Any] = {
"segmentName": self._name,
@@ -662,6 +701,43 @@ def list_data_paths(self) -> PagingList[str]:
"""
return PagingList(self._generate_data_paths, 128)

def get_data(self, remote_path: str) -> RemoteData:
"""Get required Data object from a dataset segment.
Arguments:
remote_path: The remote paths of the required data.
Returns:
:class:`~tensorbay.dataset.data.RemoteData`.
Raises:
ResourceNotExistError: When the required data does not exist.
"""
if not remote_path:
raise ResourceNotExistError(resource="data", identification=remote_path)

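# Build the RemoteData from the returned details; the URL callback lets the client refetch the file URL when needed.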
data_details = self._get_data_details(remote_path)
data = RemoteData.from_response_body(
data_details,
url=URL(data_details["url"], lambda: self._get_url(remote_path)),
cache_path=self._cache_path,
)
label = data.label

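# Attach lazy URL getters and per-type cache paths to any masks present on the label.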
for key in _MASK_KEYS:
mask = getattr(label, key, None)
if mask:
mask.url = URL.from_getter(
lambda k=key.upper(), r=remote_path: self._get_mask_url(k, r),
lambda k=key.upper(), r=remote_path: ( # type: ignore[misc, arg-type]
self._get_mask_url(k, r)
),
)
mask.cache_path = os.path.join(self._cache_path, key, mask.path)

return data

def list_data(self) -> PagingList[RemoteData]:
"""List required Data object in a dataset segment.
167 changes: 135 additions & 32 deletions tests/test_data.py
@@ -2,43 +2,63 @@
#
# Copyright 2021 Graviti. Licensed under MIT License.
#
import os

import numpy as np
import pytest

from tensorbay import GAS
from tensorbay.dataset import Data, Frame
from tensorbay.label import Catalog, InstanceMask, Label, PanopticMask, SemanticMask
from tensorbay.label.label_mask import RemoteInstanceMask, RemotePanopticMask, RemoteSemanticMask
from tensorbay.sensor import Sensor
from tests.utility import get_dataset_name

CATALOG_ATTRIBUTES = [
{"name": "gender", "enum": ["male", "female"]},
{"name": "occluded", "type": "integer", "minimum": 1, "maximum": 5},
]
MASK_CATALOG_CONTENTS = {
"categories": [
{"name": "cat", "description": "This is an exmaple of test", "categoryId": 0},
{"name": "dog", "description": "This is an exmaple of test", "categoryId": 1},
],
"attributes": CATALOG_ATTRBUTES,
}
BOX2D_CATALOG_CONTENTS = {
"categories": [
{"name": "01"},
{"name": "02"},
{"name": "03"},
{"name": "04"},
{"name": "05"},
{"name": "06"},
{"name": "07"},
{"name": "08"},
{"name": "09"},
{"name": "10"},
{"name": "11"},
{"name": "12"},
{"name": "13"},
{"name": "14"},
{"name": "15"},
],
"attributes": [
{"name": "Vertical angle", "enum": [-90, -60, -30, -15, 0, 15, 30, 60, 90]},
{
"name": "Horizontal angle",
"enum": [-90, -75, -60, -45, -30, -15, 0, 15, 30, 45, 60, 75, 90],
},
{"name": "Serie", "enum": [1, 2]},
{"name": "Number", "type": "integer", "minimum": 0, "maximum": 92},
],
}
BOX2D_CATALOG = {"BOX2D": BOX2D_CATALOG_CONTENTS}
CATALOG = {
"BOX2D": {
"categories": [
{"name": "01"},
{"name": "02"},
{"name": "03"},
{"name": "04"},
{"name": "05"},
{"name": "06"},
{"name": "07"},
{"name": "08"},
{"name": "09"},
{"name": "10"},
{"name": "11"},
{"name": "12"},
{"name": "13"},
{"name": "14"},
{"name": "15"},
],
"attributes": [
{"name": "Vertical angle", "enum": [-90, -60, -30, -15, 0, 15, 30, 60, 90]},
{
"name": "Horizontal angle",
"enum": [-90, -75, -60, -45, -30, -15, 0, 15, 30, 45, 60, 75, 90],
},
{"name": "Serie", "enum": [1, 2]},
{"name": "Number", "type": "integer", "minimum": 0, "maximum": 92},
],
}
"BOX2D": BOX2D_CATALOG_CONTENTS,
"SEMANTIC_MASK": MASK_CATALOG_CONTENTS,
"INSTANCE_MASK": MASK_CATALOG_CONTENTS,
"PANOPTIC_MASK": MASK_CATALOG_CONTENTS,
}
LABEL = {
"BOX2D": [
@@ -66,9 +86,92 @@
"rotation": {"w": 1.0, "x": 2.0, "y": 3.0, "z": 4.0},
},
}
SEMANTIC_MASK_LABEL = {
"remotePath": "hello.png",
"info": [
{"categoryId": 0, "attributes": {"occluded": True}},
{"categoryId": 1, "attributes": {"occluded": False}},
],
}
INSTANCE_MASK_LABEL = {
"remotePath": "hello.png",
"info": [
{"instanceId": 0, "attributes": {"occluded": True}},
{"instanceId": 1, "attributes": {"occluded": False}},
],
}
PANOPTIC_MASK_LABEL = {
"remotePath": "hello.png",
"info": [
{"instanceId": 100, "categoryId": 0, "attributes": {"occluded": True}},
{"instanceId": 101, "categoryId": 1, "attributes": {"occluded": False}},
],
}


@pytest.fixture
def mask_file(tmp_path):
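# Dump a small random 0/1 mask array to a temporary file named hello.png for the mask-upload tests.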
local_path = tmp_path / "hello.png"
mask = np.random.randint(0, 2, 48).reshape(8, 6)
mask.dump(local_path)
return local_path


class TestData:
def test_get_data(self, accesskey, url, tmp_path, mask_file):
gas_client = GAS(access_key=accesskey, url=url)
dataset_name = get_dataset_name()
dataset_client = gas_client.create_dataset(dataset_name)

dataset_client.create_draft("draft-1")
dataset_client.upload_catalog(Catalog.loads(CATALOG))
segment_client = dataset_client.get_or_create_segment("segment1")
path = tmp_path / "sub"
path.mkdir()

# Upload data with label
for i in range(10):
local_path = path / f"hello{i}.txt"
local_path.write_text(f"CONTENT{i}")
data = Data(local_path=str(local_path))
data.label = Label.loads(LABEL)

semantic_mask = SemanticMask(str(mask_file))
semantic_mask.all_attributes = {0: {"occluded": True}, 1: {"occluded": False}}
data.label.semantic_mask = semantic_mask

instance_mask = InstanceMask(str(mask_file))
instance_mask.all_attributes = {0: {"occluded": True}, 1: {"occluded": False}}
data.label.instance_mask = instance_mask

panoptic_mask = PanopticMask(str(mask_file))
panoptic_mask.all_category_ids = {100: 0, 101: 1}
data.label.panoptic_mask = panoptic_mask
segment_client.upload_data(data)

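# Read each uploaded file back via get_data() and check that Box2D labels and mask metadata round-trip.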
for i in range(10):
data = segment_client.get_data(f"hello{i}.txt")
assert data.path == f"hello{i}.txt"
assert data.label.box2d == Label.loads(LABEL).box2d

stem = os.path.splitext(data.path)[0]
remote_semantic_mask = data.label.semantic_mask
semantic_mask = RemoteSemanticMask.from_response_body(SEMANTIC_MASK_LABEL)
assert remote_semantic_mask.path == f"{stem}.png"
assert remote_semantic_mask.all_attributes == semantic_mask.all_attributes

remote_instance_mask = data.label.instance_mask
instance_mask = RemoteInstanceMask.from_response_body(INSTANCE_MASK_LABEL)
assert remote_instance_mask.path == f"{stem}.png"
assert remote_instance_mask.all_attributes == instance_mask.all_attributes

remote_panoptic_mask = data.label.panoptic_mask
panoptic_mask = RemotePanopticMask.from_response_body(PANOPTIC_MASK_LABEL)
assert remote_panoptic_mask.path == f"{stem}.png"
assert remote_panoptic_mask.all_category_ids == panoptic_mask.all_category_ids

gas_client.delete_dataset(dataset_name)

def test_list_file_order(self, accesskey, url, tmp_path):
gas_client = GAS(access_key=accesskey, url=url)
dataset_name = get_dataset_name()
@@ -154,7 +257,7 @@ def test_overwrite_label(self, accesskey, url, tmp_path):
dataset_name = get_dataset_name()
dataset_client = gas_client.create_dataset(dataset_name)
dataset_client.create_draft("draft-1")
dataset_client.upload_catalog(Catalog.loads(BOX2D_CATALOG))
segment_client = dataset_client.get_or_create_segment("segment1")
path = tmp_path / "sub"
path.mkdir()
@@ -182,7 +285,7 @@ def test_delete_data(self, accesskey, url, tmp_path):
dataset_name = get_dataset_name()
dataset_client = gas_client.create_dataset(dataset_name)
dataset_client.create_draft("draft-1")
dataset_client.upload_catalog(Catalog.loads(BOX2D_CATALOG))
segment_client = dataset_client.get_or_create_segment("segment1")

path = tmp_path / "sub"
@@ -205,7 +308,7 @@ def test_delete_frame(self, accesskey, url, tmp_path):
dataset_name = get_dataset_name()
dataset_client = gas_client.create_dataset(dataset_name, is_fusion=True)
dataset_client.create_draft("draft-1")
dataset_client.upload_catalog(Catalog.loads(BOX2D_CATALOG))
segment_client = dataset_client.get_or_create_segment("segment1")
segment_client.upload_sensor(Sensor.loads(LIDAR_DATA))

