9 changes: 9 additions & 0 deletions labelbox/schema/data_row.py
@@ -1,5 +1,6 @@
import logging
from typing import TYPE_CHECKING
import json

from labelbox.orm import query
from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
@@ -64,6 +65,14 @@ def __init__(self, *args, **kwargs):
self.attachments.supports_filtering = False
self.attachments.supports_sorting = False

def update(self, **kwargs):
# Convert row data to string if it is an object
# All other updates pass through
row_data = kwargs.get("row_data")
if isinstance(row_data, dict):
kwargs['row_data'] = json.dumps(row_data)
super().update(**kwargs)

@staticmethod
def bulk_delete(data_rows) -> None:
""" Deletes all the given DataRows.
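For orientation, a hedged usage sketch of the new update behavior follows; the client setup and identifiers are assumptions for illustration, not part of this diff.

from labelbox import Client

client = Client(api_key="<api key>")              # assumed credentials
data_row = client.get_data_row("<data row id>")   # assumed data row ID

# Non-dict updates pass through to the base-class update unchanged.
data_row.update(external_id="my-new-external-id")

# A dict payload is JSON-serialized before being sent, so object row data
# (e.g. tiled or conversational content) can be passed directly.
data_row.update(row_data={"pdfUrl": "https://example.com/doc.pdf"})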
94 changes: 59 additions & 35 deletions labelbox/schema/dataset.py
@@ -15,6 +15,7 @@
from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
from labelbox.orm.db_object import DbObject, Updateable, Deletable
from labelbox.orm.model import Entity, Field, Relationship
from labelbox.orm import query
from labelbox.exceptions import MalformedQueryException

if TYPE_CHECKING:
@@ -95,18 +96,46 @@ def convert_field_keys(items):
raise InvalidQueryError(
"DataRow.row_data missing when creating DataRow.")

row_data = args[DataRow.row_data.name]
if not isinstance(row_data, str):
    # If the row data is an object, serialize it as a JSON string.
    args[DataRow.row_data.name] = json.dumps(row_data)
elif os.path.exists(row_data):
    # If row data is a local file path, upload it to the server.
    args[DataRow.row_data.name] = self.client.upload_file(row_data)
args[DataRow.dataset.name] = self

# Parse metadata fields, if they are provided
if DataRow.metadata_fields.name in args:
mdo = self.client.get_data_row_metadata_ontology()
args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
args[DataRow.metadata_fields.name])

query_str = """mutation CreateDataRowPyApi(
$row_data: String!,
$metadata_fields: [DataRowCustomMetadataUpsertInput!],
$attachments: [DataRowAttachmentInput!],
$media_type : MediaType,
$external_id : String,
$global_key : String,
$dataset: ID!
){
createDataRow(
data:
{
rowData: $row_data
mediaType: $media_type
metadataFields: $metadata_fields
externalId: $external_id
globalKey: $global_key
attachments: $attachments
dataset: {connect: {id: $dataset}}
}
)
{%s}
}
""" % query.results_query_part(Entity.DataRow)
res = self.client.execute(query_str, {**args, 'dataset': self.uid})
return DataRow(self.client, res['createDataRow'])
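A short, hedged sketch of how this mutation path might be exercised; the dataset handle and all values are illustrative assumptions.

# `dataset` is an existing labelbox Dataset; values are illustrative.
# A URL (or local file path) still works exactly as before.
dr = dataset.create_data_row(row_data="https://example.com/img_01.jpg",
                             media_type="IMAGE")

# An object payload is JSON-serialized and sent through the
# CreateDataRowPyApi mutation above, alongside optional metadata_fields,
# attachments, external_id, and global_key.
dr = dataset.create_data_row(
    row_data={
        "tileLayerUrl": "https://example.com/tiles/{z}/{x}/{y}.png",
        "bounds": [[19.41, -99.21], [19.40, -99.20]],
        "minZoom": 12,
        "maxZoom": 20,
        "epsg": "EPSG4326"
    },
    media_type="TMS_SIMPLE")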

def create_data_rows_sync(self, items) -> None:
""" Synchronously bulk upload data rows.
@@ -229,8 +258,8 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
>>> {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
>>> {DataRow.row_data:"/path/to/file1.jpg"},
>>> "path/to/file2.jpg",
>>> {"tileLayerUrl" : "http://", ...}
>>> {"conversationalData" : [...], ...}
>>> {DataRow.row_data: {"tileLayerUrl" : "http://", ...}}
>>> {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}}
>>> ])

For an example showing how to upload tiled data_rows see the following notebook:
@@ -258,7 +287,7 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):

def upload_if_necessary(item):
row_data = item['row_data']
if isinstance(row_data, str) and os.path.exists(row_data):
item_url = self.client.upload_file(row_data)
item['row_data'] = item_url
if 'external_id' not in item:
@@ -341,40 +370,39 @@ def validate_keys(item):
"`row_data` missing when creating DataRow.")

invalid_keys = set(item) - {
*{f.name for f in DataRow.fields()}, 'attachments', 'media_type'
}
if invalid_keys:
raise InvalidAttributeError(DataRow, invalid_keys)
return item

def formatLegacyConversationalData(item):
messages = item.pop("conversationalData")
version = item.pop("version", 1)
type = item.pop("type", "application/vnd.labelbox.conversational")
if "externalId" in item:
external_id = item.pop("externalId")
item["external_id"] = external_id
if "globalKey" in item:
global_key = item.pop("globalKey")
item["globalKey"] = global_key
validate_conversational_data(messages)
one_conversation = \
{
"type": type,
"version": version,
"messages": messages
}
item["row_data"] = one_conversation
return item

def convert_item(item):
# Don't make any changes to tms data
if "tileLayerUrl" in item:
validate_attachments(item)
return item

if "conversationalData" in item:
messages = item.pop("conversationalData")
version = item.pop("version")
type = item.pop("type")
if "externalId" in item:
external_id = item.pop("externalId")
item["external_id"] = external_id
if "globalKey" in item:
global_key = item.pop("globalKey")
item["globalKey"] = global_key
validate_conversational_data(messages)
one_conversation = \
{
"type": type,
"version": version,
"messages": messages
}
conversationUrl = self.client.upload_data(
json.dumps(one_conversation),
content_type="application/json",
filename="conversational_data.json")
item["row_data"] = conversationUrl
formatLegacyConversationalData(item)

# Convert all payload variations into the same dict format
item = format_row(item)
@@ -386,11 +414,7 @@ def convert_item(item):
parse_metadata_fields(item)
# Upload any local file paths
item = upload_if_necessary(item)

return item

if not isinstance(items, Iterable):
raise ValueError(
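To make the legacy conversion above concrete, here is a hedged before/after sketch of what formatLegacyConversationalData does to an item. The helper is nested inside _create_descriptor_file, so this is illustrative rather than directly importable, and the message content is an assumption.

# Input item using the old top-level key:
item = {"conversationalData": [{"messageId": "message-0", "content": "hi"}]}

formatLegacyConversationalData(item)

# The messages are validated and folded into row_data, with type and
# version defaulted when absent:
# {
#     "row_data": {
#         "type": "application/vnd.labelbox.conversational",
#         "version": 1,
#         "messages": [{"messageId": "message-0", "content": "hi"}]
#     }
# }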
4 changes: 2 additions & 2 deletions labelbox/schema/media_type.py
@@ -21,9 +21,9 @@ class MediaType(Enum):

@classmethod
def _missing_(cls, name):
"""Handle missing null data types for projects
"""Handle missing null data types for projects
created without setting allowedMediaType
Handle upper case names for compatibility with
Handle upper case names for compatibility with
the GraphQL"""

if name is None:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_data_row_media_attributes.py
@@ -7,4 +7,4 @@ def test_export_empty_media_attributes(configured_project_with_label):
sleep(10)
labels = project.label_generator()
label = next(labels)
assert label.data.media_attributes == {}
134 changes: 134 additions & 0 deletions tests/integration/test_data_rows.py
@@ -1,6 +1,7 @@
from tempfile import NamedTemporaryFile
import uuid
from datetime import datetime
import json

import pytest
import requests
@@ -28,6 +29,56 @@ def mdo(client):
yield mdo


@pytest.fixture
def conversational_content():
return {
'row_data': {
"messages": [{
"messageId": "message-0",
"timestampUsec": 1530718491,
"content": "I love iphone! i just bought new iphone! 🥰 📲",
"user": {
"userId": "Bot 002",
"name": "Bot"
},
"align": "left",
"canLabel": False
}],
"version": 1,
"type": "application/vnd.labelbox.conversational"
}
}


@pytest.fixture
def tile_content():
return {
    "row_data": {
        "tileLayerUrl":
            "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png",
        "bounds": [[19.405662413477728, -99.21052827588443],
                   [19.400498983095076, -99.20534818927473]],
        "minZoom": 12,
        "maxZoom": 20,
        "epsg": "EPSG4326",
        "alternativeLayers": [{
            "tileLayerUrl":
                "https://api.mapbox.com/styles/v1/mapbox/satellite-streets-v11/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw",
            "name": "Satellite"
        }, {
            "tileLayerUrl":
                "https://api.mapbox.com/styles/v1/mapbox/navigation-guidance-night-v4/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw",
            "name": "Guidance"
        }]
    }
}


def make_metadata_fields():
embeddings = [0.0] * 128
msg = "A message"
@@ -408,6 +459,18 @@ def test_data_row_update(dataset, rand_gen, image_url):
data_row.update(external_id=external_id_2)
assert data_row.external_id == external_id_2

in_line_content = "123"
data_row.update(row_data=in_line_content)
assert requests.get(data_row.row_data).text == in_line_content

data_row.update(row_data=image_url)
assert data_row.row_data == image_url

# tileLayer becomes a media attribute
pdf_url = "http://somepdfurl"
data_row.update(row_data={'pdfUrl': pdf_url, "tileLayerUrl": "123"})
assert data_row.row_data == pdf_url


def test_data_row_filtering_sorting(dataset, image_url):
task = dataset.create_data_rows([
@@ -696,3 +759,74 @@ def test_data_row_rulk_creation_sync_with_same_global_keys(

assert len(list(dataset.data_rows())) == 1
assert list(dataset.data_rows())[0].global_key == global_key_1


def test_create_conversational_text(dataset, conversational_content):
examples = [
{
**conversational_content, 'media_type': 'CONVERSATIONAL'
},
conversational_content,
{
"conversationalData": conversational_content['row_data']['messages']
} # Old way to check for backwards compatibility
]
dataset.create_data_rows_sync(examples)
data_rows = list(dataset.data_rows())
assert len(data_rows) == len(examples)
for data_row in data_rows:
assert requests.get(
data_row.row_data).json() == conversational_content['row_data']


def test_invalid_media_type(dataset, conversational_content):
for error_message, invalid_media_type in [[
"Found invalid contents for media type: 'IMAGE'", 'IMAGE'
], ["Found invalid media type: 'totallyinvalid'", 'totallyinvalid']]:
# TODO: What error kind should this be? It looks like for global key we are
# using malformed query. But for invalid contents in FileUploads we use InvalidQueryError
with pytest.raises(labelbox.exceptions.InvalidQueryError):
dataset.create_data_rows_sync([{
**conversational_content, 'media_type': invalid_media_type
}])

task = dataset.create_data_rows([{
**conversational_content, 'media_type': invalid_media_type
}])
task.wait_till_done()
assert task.errors == {'message': error_message}


def test_create_tiled_layer(dataset, tile_content):
examples = [
{
**tile_content, 'media_type': 'TMS_SIMPLE'
},
tile_content,
tile_content['row_data'] # Old way to check for backwards compatibility
]
dataset.create_data_rows_sync(examples)
data_rows = list(dataset.data_rows())
assert len(data_rows) == len(examples)
for data_row in data_rows:
assert json.loads(data_row.row_data) == tile_content['row_data']


def test_create_data_row_with_attachments(dataset):
attachment_value = 'attachment value'
dr = dataset.create_data_row(row_data="123",
attachments=[{
'type': 'TEXT',
'value': attachment_value
}])
attachments = list(dr.attachments())
assert len(attachments) == 1


def test_create_data_row_with_media_type(dataset, image_url):
with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
dr = dataset.create_data_row(
row_data={'invalid_object': 'invalid_value'}, media_type="IMAGE")
assert "Found invalid contents for media type: \'IMAGE\'" in str(exc.value)

dataset.create_data_row(row_data=image_url, media_type="IMAGE")