Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions labelbox/schema/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
response = requests.get(download_url)
response.raise_for_status()
reader = ndjson.reader(StringIO(response.text))
# TODO: Update result to parse customMetadata when resolver returns
# TODO: Update result to parse metadataFields when resolver returns
return (Entity.DataRow(self.client, {
**result, 'customMetadata': []
**result, 'metadataFields': []
}) for result in reader)
elif res["status"] == "FAILED":
raise LabelboxError("Data row export failed.")
Expand Down
8 changes: 4 additions & 4 deletions labelbox/schema/data_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DataRow(DbObject, Updateable, BulkDeletable):
updated_at (datetime)
created_at (datetime)
media_attributes (dict): generated media attributes for the datarow
custom_metadata (list): metadata associated with the datarow
metadata_fields (list): metadata associated with the datarow

dataset (Relationship): `ToOne` relationship to Dataset
created_by (Relationship): `ToOne` relationship to User
Expand All @@ -35,11 +35,11 @@ class DataRow(DbObject, Updateable, BulkDeletable):
updated_at = Field.DateTime("updated_at")
created_at = Field.DateTime("created_at")
media_attributes = Field.Json("media_attributes")
custom_metadata = Field.List(
metadata_fields = Field.List(
DataRowMetadataField,
graphql_type="DataRowCustomMetadataUpsertInput!",
name="custom_metadata",
result_subquery="customMetadata { value schemaId }")
name="metadata_fields",
result_subquery="metadataFields { schemaId name value kind }")

# Relationships
dataset = Relationship.ToOne("Dataset")
Expand Down
45 changes: 31 additions & 14 deletions labelbox/schema/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,40 +52,57 @@ class Dataset(DbObject, Updateable, Deletable):
iam_integration = Relationship.ToOne("IAMIntegration", False,
"iam_integration", "signer")

def create_data_row(self, **kwargs) -> "DataRow":
def create_data_row(self, items=None, **kwargs) -> "DataRow":
""" Creates a single DataRow belonging to this dataset.

>>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg")

Args:
items: Dictionary containing new `DataRow` data. At a minimum,
must contain `row_data` or `DataRow.row_data`.
**kwargs: Key-value arguments containing new `DataRow` data. At a minimum,
must contain `row_data`.

Raises:
InvalidQueryError: If both dictionary and `kwargs` are provided as inputs
InvalidQueryError: If `DataRow.row_data` field value is not provided
in `kwargs`.
InvalidAttributeError: in case the DB object type does not contain
any of the field names given in `kwargs`.

"""
invalid_argument_error = "Argument to create_data_row() must be either a dictionary, or kwargs containing `row_data` at minimum"

def convert_field_keys(items):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd maybe move this out of the function scope and into module scope: it looks like a general-purpose utility function, and it doesn't need to be redefined every time this method is called.

if not isinstance(items, dict):
raise InvalidQueryError(invalid_argument_error)
return {
key.name if isinstance(key, Field) else key: value
for key, value in items.items()
}

if items is not None and len(kwargs) > 0:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`if items` is maybe more concise and strict.

raise InvalidQueryError(invalid_argument_error)

DataRow = Entity.DataRow
if DataRow.row_data.name not in kwargs:
args = convert_field_keys(items) if items is not None else kwargs

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`if items` is a little more concise.


if DataRow.row_data.name not in args:
raise InvalidQueryError(
"DataRow.row_data missing when creating DataRow.")

# If row data is a local file path, upload it to server.
row_data = kwargs[DataRow.row_data.name]
row_data = args[DataRow.row_data.name]
if os.path.exists(row_data):
kwargs[DataRow.row_data.name] = self.client.upload_file(row_data)
kwargs[DataRow.dataset.name] = self
args[DataRow.row_data.name] = self.client.upload_file(row_data)
args[DataRow.dataset.name] = self

# Parse metadata fields, if they are provided
if DataRow.custom_metadata.name in kwargs:
if DataRow.metadata_fields.name in args:
mdo = self.client.get_data_row_metadata_ontology()
kwargs[DataRow.custom_metadata.name] = mdo.parse_upsert_metadata(
kwargs[DataRow.custom_metadata.name])

return self.client._create(DataRow, kwargs)
args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
args[DataRow.metadata_fields.name])
return self.client._create(DataRow, args)

def create_data_rows_sync(self, items) -> None:
""" Synchronously bulk upload data rows.
Expand Down Expand Up @@ -264,10 +281,10 @@ def validate_attachments(item):
return attachments

def parse_metadata_fields(item):
metadata_fields = item.get('custom_metadata')
metadata_fields = item.get('metadata_fields')
if metadata_fields:
mdo = self.client.get_data_row_metadata_ontology()
item['custom_metadata'] = mdo.parse_upsert_metadata(
item['metadata_fields'] = mdo.parse_upsert_metadata(
metadata_fields)

def format_row(item):
Expand Down Expand Up @@ -413,9 +430,9 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
response = requests.get(download_url)
response.raise_for_status()
reader = ndjson.reader(StringIO(response.text))
# TODO: Update result to parse customMetadata when resolver returns
# TODO: Update result to parse metadataFields when resolver returns
return (Entity.DataRow(self.client, {
**result, 'customMetadata': []
**result, 'metadataFields': []
}) for result in reader)
elif res["status"] == "FAILED":
raise LabelboxError("Data row export failed.")
Expand Down
86 changes: 69 additions & 17 deletions tests/integration/test_data_rows.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import imghdr
from tempfile import NamedTemporaryFile
import uuid
import time
Expand Down Expand Up @@ -177,12 +178,63 @@ def test_data_row_single_creation(dataset, rand_gen, image_url):
assert requests.get(data_row_2.row_data).content == data


def test_create_data_row_with_dict(dataset, image_url):
    """Create a data row by passing a plain ``{"row_data": ...}`` dictionary."""
    owner_client = dataset.client

    # The dataset is expected to contain no rows before creation.
    assert len(list(dataset.data_rows())) == 0

    payload = {"row_data": image_url}
    created = dataset.create_data_row(payload)

    # Exactly one row exists now, linked to this dataset, user and org.
    assert len(list(dataset.data_rows())) == 1
    assert created.dataset() == dataset
    assert created.created_by() == owner_client.get_user()
    assert created.organization() == owner_client.get_organization()

    # The stored row must serve the same bytes as the source image.
    source_bytes = requests.get(image_url).content
    stored_bytes = requests.get(created.row_data).content
    assert source_bytes == stored_bytes

    assert created.media_attributes is not None


def test_create_data_row_with_dict_containing_field(dataset, image_url):
    """Create a data row from a dictionary keyed by ``DataRow.row_data``."""
    owner_client = dataset.client

    # The dataset is expected to contain no rows before creation.
    assert len(list(dataset.data_rows())) == 0

    # The key is the Field object itself rather than its string name.
    payload = {DataRow.row_data: image_url}
    created = dataset.create_data_row(payload)

    # Exactly one row exists now, linked to this dataset, user and org.
    assert len(list(dataset.data_rows())) == 1
    assert created.dataset() == dataset
    assert created.created_by() == owner_client.get_user()
    assert created.organization() == owner_client.get_organization()

    # The stored row must serve the same bytes as the source image.
    source_bytes = requests.get(image_url).content
    stored_bytes = requests.get(created.row_data).content
    assert source_bytes == stored_bytes

    assert created.media_attributes is not None


def test_create_data_row_with_dict_unpacked(dataset, image_url):
    """Create a data row by unpacking a dictionary into keyword arguments."""
    owner_client = dataset.client

    # The dataset is expected to contain no rows before creation.
    assert len(list(dataset.data_rows())) == 0

    payload = {"row_data": image_url}
    # Unpacked, so the values arrive as kwargs rather than as `items`.
    created = dataset.create_data_row(**payload)

    # Exactly one row exists now, linked to this dataset, user and org.
    assert len(list(dataset.data_rows())) == 1
    assert created.dataset() == dataset
    assert created.created_by() == owner_client.get_user()
    assert created.organization() == owner_client.get_organization()

    # The stored row must serve the same bytes as the source image.
    source_bytes = requests.get(image_url).content
    stored_bytes = requests.get(created.row_data).content
    assert source_bytes == stored_bytes

    assert created.media_attributes is not None


def test_create_data_row_with_invalid_input(dataset, image_url):
    """Check that malformed ``create_data_row`` arguments raise InvalidQueryError."""
    # A positional argument that is not a dictionary.
    with pytest.raises(labelbox.exceptions.InvalidQueryError):
        dataset.create_data_row("asdf")

    # A dictionary and keyword arguments supplied together.
    payload = {"row_data": image_url}
    with pytest.raises(labelbox.exceptions.InvalidQueryError):
        dataset.create_data_row(payload, row_data=image_url)


def test_create_data_row_with_metadata(dataset, image_url):
client = dataset.client
assert len(list(dataset.data_rows())) == 0

data_row = dataset.create_data_row(row_data=image_url,
custom_metadata=make_metadata_fields())
metadata_fields=make_metadata_fields())

assert len(list(dataset.data_rows())) == 1
assert data_row.dataset() == dataset
Expand All @@ -191,8 +243,8 @@ def test_create_data_row_with_metadata(dataset, image_url):
assert requests.get(image_url).content == \
requests.get(data_row.row_data).content
assert data_row.media_attributes is not None
assert len(data_row.custom_metadata) == 5
assert [m["schemaId"] for m in data_row.custom_metadata
assert len(data_row.metadata_fields) == 4
assert [m["schemaId"] for m in data_row.metadata_fields
].sort() == EXPECTED_METADATA_SCHEMA_IDS


Expand All @@ -201,7 +253,7 @@ def test_create_data_row_with_metadata_dict(dataset, image_url):
assert len(list(dataset.data_rows())) == 0

data_row = dataset.create_data_row(
row_data=image_url, custom_metadata=make_metadata_fields_dict())
row_data=image_url, metadata_fields=make_metadata_fields_dict())

assert len(list(dataset.data_rows())) == 1
assert data_row.dataset() == dataset
Expand All @@ -210,8 +262,8 @@ def test_create_data_row_with_metadata_dict(dataset, image_url):
assert requests.get(image_url).content == \
requests.get(data_row.row_data).content
assert data_row.media_attributes is not None
assert len(data_row.custom_metadata) == 5
assert [m["schemaId"] for m in data_row.custom_metadata
assert len(data_row.metadata_fields) == 4
assert [m["schemaId"] for m in data_row.metadata_fields
].sort() == EXPECTED_METADATA_SCHEMA_IDS


Expand All @@ -221,7 +273,7 @@ def test_create_data_row_with_invalid_metadata(dataset, image_url):
DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID, value=[0.0] * 128))

with pytest.raises(labelbox.exceptions.MalformedQueryException) as excinfo:
dataset.create_data_row(row_data=image_url, custom_metadata=fields)
dataset.create_data_row(row_data=image_url, metadata_fields=fields)


def test_create_data_rows_with_metadata(dataset, image_url):
Expand All @@ -232,22 +284,22 @@ def test_create_data_rows_with_metadata(dataset, image_url):
{
DataRow.row_data: image_url,
DataRow.external_id: "row1",
DataRow.custom_metadata: make_metadata_fields()
DataRow.metadata_fields: make_metadata_fields()
},
{
DataRow.row_data: image_url,
DataRow.external_id: "row2",
"custom_metadata": make_metadata_fields()
"metadata_fields": make_metadata_fields()
},
{
DataRow.row_data: image_url,
DataRow.external_id: "row3",
DataRow.custom_metadata: make_metadata_fields_dict()
DataRow.metadata_fields: make_metadata_fields_dict()
},
{
DataRow.row_data: image_url,
DataRow.external_id: "row4",
"custom_metadata": make_metadata_fields_dict()
"metadata_fields": make_metadata_fields_dict()
},
])
task.wait_till_done()
Expand All @@ -261,8 +313,8 @@ def test_create_data_rows_with_metadata(dataset, image_url):
assert requests.get(image_url).content == \
requests.get(row.row_data).content
assert row.media_attributes is not None
assert len(row.custom_metadata) == 5
assert [m["schemaId"] for m in row.custom_metadata
assert len(row.metadata_fields) == 4
assert [m["schemaId"] for m in row.metadata_fields
].sort() == EXPECTED_METADATA_SCHEMA_IDS


Expand All @@ -273,7 +325,7 @@ def test_create_data_rows_with_invalid_metadata(dataset, image_url):

task = dataset.create_data_rows([{
DataRow.row_data: image_url,
DataRow.custom_metadata: fields
DataRow.metadata_fields: fields
}])
task.wait_till_done()
assert task.status == "FAILED"
Expand All @@ -288,7 +340,7 @@ def test_create_data_rows_with_metadata_missing_value(dataset, image_url):
{
DataRow.row_data: image_url,
DataRow.external_id: "row1",
DataRow.custom_metadata: fields
DataRow.metadata_fields: fields
},
])

Expand All @@ -302,7 +354,7 @@ def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url):
{
DataRow.row_data: image_url,
DataRow.external_id: "row1",
DataRow.custom_metadata: fields
DataRow.metadata_fields: fields
},
])

Expand All @@ -316,7 +368,7 @@ def test_create_data_rows_with_metadata_wrong_type(dataset, image_url):
{
DataRow.row_data: image_url,
DataRow.external_id: "row1",
DataRow.custom_metadata: fields
DataRow.metadata_fields: fields
},
])

Expand Down