Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 260 additions & 36 deletions labelbox/schema/data_row_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ class _DeleteBatchDataRowMetadata(_CamelCaseMixin):
_BatchFunction = Callable[[_BatchInputs], List[DataRowMetadataBatchResponse]]


class _UpsertCustomMetadataSchemaEnumOptionInput(_CamelCaseMixin):
    """GraphQL input payload for one Enum option in a metadata-schema upsert."""
    # None when creating a brand-new option; the existing option id on update.
    id: Optional[SchemaId]
    # Option display name; whitespace-stripped, 1-100 characters (validated by pydantic).
    name: constr(strip_whitespace=True, min_length=1, max_length=100)
    # Serialized kind string — callers pass `DataRowMetadataKind.option.value`.
    kind: str


class _UpsertCustomMetadataSchemaInput(_CamelCaseMixin):
    """GraphQL input payload for creating or updating a custom metadata schema."""
    # None on create; the existing schema id on update.
    id: Optional[SchemaId]
    # Schema display name; whitespace-stripped, 1-100 characters (validated by pydantic).
    name: constr(strip_whitespace=True, min_length=1, max_length=100)
    # Serialized `DataRowMetadataKind` value.
    kind: str
    # Enum options; populated only when the schema kind is Enum.
    options: Optional[List[_UpsertCustomMetadataSchemaEnumOptionInput]]


class DataRowMetadataOntology:
""" Ontology for data row metadata

Expand Down Expand Up @@ -122,21 +135,30 @@ def _build_ontology(self):
f for f in self.fields if f.reserved
]
self.reserved_by_id = self._make_id_index(self.reserved_fields)
self.reserved_by_name: Dict[
str,
DataRowMetadataSchema] = self._make_name_index(self.reserved_fields)
self.reserved_by_name: Dict[str, Union[DataRowMetadataSchema, Dict[
str, DataRowMetadataSchema]]] = self._make_name_index(
self.reserved_fields)
self.reserved_by_name_normalized: Dict[
str, DataRowMetadataSchema] = self._make_normalized_name_index(
self.reserved_fields)

# custom fields
self.custom_fields: List[DataRowMetadataSchema] = [
f for f in self.fields if not f.reserved
]
self.custom_by_id = self._make_id_index(self.custom_fields)
self.custom_by_name: Dict[
self.custom_by_name: Dict[str, Union[DataRowMetadataSchema, Dict[
str,
DataRowMetadataSchema] = self._make_name_index(self.custom_fields)
DataRowMetadataSchema]]] = self._make_name_index(self.custom_fields)
self.custom_by_name_normalized: Dict[
str, DataRowMetadataSchema] = self._make_normalized_name_index(
self.custom_fields)

@staticmethod
def _make_name_index(fields: List[DataRowMetadataSchema]):
def _make_name_index(
fields: List[DataRowMetadataSchema]
) -> Dict[str, Union[DataRowMetadataSchema, Dict[str,
DataRowMetadataSchema]]]:
index = {}
for f in fields:
if f.options:
Expand All @@ -147,6 +169,15 @@ def _make_name_index(fields: List[DataRowMetadataSchema]):
index[f.name] = f
return index

@staticmethod
def _make_normalized_name_index(
fields: List[DataRowMetadataSchema]
) -> Dict[str, DataRowMetadataSchema]:
index = {}
for f in fields:
index[f.name] = f
return index

@staticmethod
def _make_id_index(
fields: List[DataRowMetadataSchema]
Expand Down Expand Up @@ -200,9 +231,144 @@ def _parse_ontology(raw_ontology) -> List[DataRowMetadataSchema]:
return fields

    def refresh_ontology(self):
        """ Update the `DataRowMetadataOntology` instance with the latest
        metadata ontology schemas
        """
        # Re-fetch the raw ontology from the backend, then rebuild every
        # cached index (by id / by name / normalized) from the fresh data.
        self._raw_ontology = self._get_ontology()
        self._build_ontology()

def create_schema(self,
name: str,
kind: DataRowMetadataKind,
options: List[str] = None) -> DataRowMetadataSchema:
""" Create metadata schema

>>> mdo.create_schema(name, kind, options)

Args:
name (str): Name of metadata schema
kind (DataRowMetadataKind): Kind of metadata schema as `DataRowMetadataKind`
options (List[str]): List of Enum options

Returns:
Created metadata schema as `DataRowMetadataSchema`

Raises:
KeyError: When provided name is not a valid custom metadata
"""
if not isinstance(kind, DataRowMetadataKind):
raise ValueError(f"kind '{kind}' must be a `DataRowMetadataKind`")

upsert_schema = _UpsertCustomMetadataSchemaInput(name=name,
kind=kind.value)
if options:
if kind != DataRowMetadataKind.enum:
raise ValueError(
f"Kind '{kind}' must be an Enum, if Enum options are provided"
)
upsert_enum_options = [
_UpsertCustomMetadataSchemaEnumOptionInput(
name=o, kind=DataRowMetadataKind.option.value)
for o in options
]
upsert_schema.options = upsert_enum_options

return self._upsert_schema(upsert_schema)

def update_schema(self, name: str, new_name: str) -> DataRowMetadataSchema:
""" Update metadata schema

>>> mdo.update_schema(name, new_name)

Args:
name (str): Current name of metadata schema
new_name (str): New name of metadata schema

Returns:
Updated metadata schema as `DataRowMetadataSchema`

Raises:
KeyError: When provided name is not a valid custom metadata
"""
schema = self._validate_custom_schema_by_name(name)
upsert_schema = _UpsertCustomMetadataSchemaInput(id=schema.uid,
name=new_name,
kind=schema.kind.value)
if schema.options:
upsert_enum_options = [
_UpsertCustomMetadataSchemaEnumOptionInput(
id=o.uid,
name=o.name,
kind=DataRowMetadataKind.option.value)
for o in schema.options
]
upsert_schema.options = upsert_enum_options

return self._upsert_schema(upsert_schema)

def update_enum_option(self, name: str, option: str,
new_option: str) -> DataRowMetadataSchema:
""" Update Enum metadata schema option

>>> mdo.update_enum_option(name, option, new_option)

Args:
name (str): Name of metadata schema to update
option (str): Name of Enum option to update
new_option (str): New name of Enum option

Returns:
Updated metadata schema as `DataRowMetadataSchema`

Raises:
KeyError: When provided name is not a valid custom metadata
"""
schema = self._validate_custom_schema_by_name(name)
if schema.kind != DataRowMetadataKind.enum:
raise ValueError(
f"Updating Enum option is only supported for Enum metadata schema"
)

upsert_schema = _UpsertCustomMetadataSchemaInput(id=schema.uid,
name=schema.name,
kind=schema.kind.value)
upsert_enum_options = []
for o in schema.options:
enum_option = _UpsertCustomMetadataSchemaEnumOptionInput(
id=o.uid, name=o.name, kind=o.kind.value)
if enum_option.name == option:
enum_option.name = new_option
upsert_enum_options.append(enum_option)
upsert_schema.options = upsert_enum_options

return self._upsert_schema(upsert_schema)

def delete_schema(self, name: str) -> bool:
""" Delete metadata schema

>>> mdo.delete_schema(name)

Args:
name: Name of metadata schema to delete

Returns:
True if deletion is successful, False if unsuccessful

Raises:
KeyError: When provided name is not a valid custom metadata
"""
schema = self._validate_custom_schema_by_name(name)
query = """mutation DeleteCustomMetadataSchemaPyApi($where: WhereUniqueIdInput!) {
deleteCustomMetadataSchema(schema: $where){
success
}
}"""
res = self._client.execute(query, {'where': {
'id': schema.uid
}})['deleteCustomMetadataSchema']

return res['success']

def parse_metadata(
self, unparsed: List[Dict[str,
List[Union[str,
Expand Down Expand Up @@ -248,7 +414,7 @@ def parse_metadata_fields(

for f in unparsed:
if f["schemaId"] not in self.fields_by_id:
# Update metadata ontology if field can't be found
# Fetch latest metadata ontology if metadata can't be found
self.refresh_ontology()
if f["schemaId"] not in self.fields_by_id:
raise ValueError(
Expand Down Expand Up @@ -422,13 +588,69 @@ def _bulk_export(_data_row_ids: List[str]) -> List[DataRowMetadata]:
data_row_ids,
batch_size=self._batch_size)

def parse_upsert_metadata(self, metadata_fields) -> List[Dict[str, Any]]:
""" Converts either `DataRowMetadataField` or a dictionary representation
of `DataRowMetadataField` into a validated, flattened dictionary of
metadata fields that are used to create data row metadata. Used
internally in `Dataset.create_data_rows()`

Args:
metadata_fields: List of `DataRowMetadataField` or a dictionary representation
of `DataRowMetadataField`
Returns:
List of dictionaries representing a flattened view of metadata fields
"""

def _convert_metadata_field(metadata_field):
if isinstance(metadata_field, DataRowMetadataField):
return metadata_field
elif isinstance(metadata_field, dict):
if not all(key in metadata_field
for key in ("schema_id", "value")):
raise ValueError(
f"Custom metadata field '{metadata_field}' must have 'schema_id' and 'value' keys"
)
return DataRowMetadataField(
schema_id=metadata_field["schema_id"],
value=metadata_field["value"])
else:
raise ValueError(
f"Metadata field '{metadata_field}' is neither 'DataRowMetadataField' type or a dictionary"
)

# Convert all metadata fields to DataRowMetadataField type
metadata_fields = [_convert_metadata_field(m) for m in metadata_fields]
parsed_metadata = list(
chain.from_iterable(self._parse_upsert(m) for m in metadata_fields))
return [m.dict(by_alias=True) for m in parsed_metadata]

def _upsert_schema(
self, upsert_schema: _UpsertCustomMetadataSchemaInput
) -> DataRowMetadataSchema:
query = """mutation UpsertCustomMetadataSchemaPyApi($data: UpsertCustomMetadataSchemaInput!) {
upsertCustomMetadataSchema(data: $data){
id
name
kind
options {
id
name
kind
}
}
}"""
res = self._client.execute(
query, {"data": upsert_schema.dict(exclude_none=True)
})['upsertCustomMetadataSchema']
return _parse_metadata_schema(res)

def _parse_upsert(
self, metadatum: DataRowMetadataField
) -> List[_UpsertDataRowMetadataInput]:
"""Format for metadata upserts to GQL"""

if metadatum.schema_id not in self.fields_by_id:
# Update metadata ontology if field can't be found
# Fetch latest metadata ontology if metadata can't be found
self.refresh_ontology()
if metadatum.schema_id not in self.fields_by_id:
raise ValueError(
Expand All @@ -453,41 +675,14 @@ def _parse_upsert(

return [_UpsertDataRowMetadataInput(**p) for p in parsed]

# Convert metadata to DataRowMetadataField objects, parse all fields
# and return a dictionary of metadata fields for upsert
    def parse_upsert_metadata(self, metadata_fields):
        """Convert metadata to `DataRowMetadataField` objects, parse all fields
        and return a list of flattened dictionaries ready for upsert.

        Args:
            metadata_fields: List of `DataRowMetadataField` or dicts carrying
                'schema_id' and 'value' keys.
        Returns:
            List of dictionaries representing a flattened view of metadata fields
        Raises:
            ValueError: If an entry is neither a `DataRowMetadataField` nor a
                dict with the required keys.
        """

        def _convert_metadata_field(metadata_field):
            # Pass through values that already are DataRowMetadataField.
            if isinstance(metadata_field, DataRowMetadataField):
                return metadata_field
            elif isinstance(metadata_field, dict):
                if not all(key in metadata_field
                           for key in ("schema_id", "value")):
                    raise ValueError(
                        f"Custom metadata field '{metadata_field}' must have 'schema_id' and 'value' keys"
                    )
                return DataRowMetadataField(
                    schema_id=metadata_field["schema_id"],
                    value=metadata_field["value"])
            else:
                raise ValueError(
                    f"Metadata field '{metadata_field}' is neither 'DataRowMetadataField' type or a dictionary"
                )

        # Convert all metadata fields to DataRowMetadataField type
        metadata_fields = [_convert_metadata_field(m) for m in metadata_fields]
        parsed_metadata = list(
            chain.from_iterable(self._parse_upsert(m) for m in metadata_fields))
        return [m.dict(by_alias=True) for m in parsed_metadata]

def _validate_delete(self, delete: DeleteDataRowMetadata):
if not len(delete.fields):
raise ValueError(f"No fields specified for {delete.data_row_id}")

deletes = set()
for schema_id in delete.fields:
if schema_id not in self.fields_by_id:
# Update metadata ontology if field can't be found
# Fetch latest metadata ontology if metadata can't be found
self.refresh_ontology()
if schema_id not in self.fields_by_id:
raise ValueError(
Expand All @@ -504,6 +699,16 @@ def _validate_delete(self, delete: DeleteDataRowMetadata):
data_row_id=delete.data_row_id,
schema_ids=list(delete.fields)).dict(by_alias=True)

def _validate_custom_schema_by_name(self,
name: str) -> DataRowMetadataSchema:
if name not in self.custom_by_name_normalized:
# Fetch latest metadata ontology if metadata can't be found
self.refresh_ontology()
if name not in self.custom_by_name_normalized:
raise KeyError(f"'{name}' is not a valid custom metadata")

return self.custom_by_name_normalized[name]


def _batch_items(iterable: List[Any], size: int) -> Generator[Any, None, None]:
l = len(iterable)
Expand Down Expand Up @@ -596,3 +801,22 @@ def _validate_enum_parse(
"schemaId": field.value,
"value": {}
}]


def _parse_metadata_schema(
        unparsed: Dict[str, Union[str, List]]) -> DataRowMetadataSchema:
    """Build a `DataRowMetadataSchema` from a raw GraphQL schema payload.

    Args:
        unparsed: Dict with 'id', 'name', 'kind' and (optionally) 'options'
            keys, as returned by `upsertCustomMetadataSchema`.

    Returns:
        Parsed `DataRowMetadataSchema`; `options` is None when the payload
        carried no Enum options.
    """
    uid = unparsed['id']
    name = unparsed['name']
    kind = DataRowMetadataKind(unparsed['kind'])
    # Non-Enum schemas may come back with a missing or null 'options' field;
    # treat both the same as an empty list instead of crashing.
    options = [
        DataRowMetadataSchema(uid=o['id'],
                              name=o['name'],
                              reserved=False,
                              kind=DataRowMetadataKind.option,
                              parent=uid) for o in (unparsed.get('options') or [])
    ]
    return DataRowMetadataSchema(uid=uid,
                                 name=name,
                                 reserved=False,
                                 kind=kind,
                                 options=options or None)
2 changes: 1 addition & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def graphql_url(environ: str) -> str:
if environ == Environ.PROD:
return 'https://api.labelbox.com/graphql'
elif environ == Environ.STAGING:
return 'https://staging-api.labelbox.com/graphql'
return 'https://api.lb-stage.xyz/graphql'
elif environ == Environ.ONPREM:
hostname = os.environ.get('LABELBOX_TEST_ONPREM_HOSTNAME', None)
if hostname is None:
Expand Down
Loading