From 0293dba2e16c74e55104e88b25c53a54158ee668 Mon Sep 17 00:00:00 2001 From: Matt Sokoloff Date: Wed, 16 Jun 2021 12:28:17 -0400 Subject: [PATCH 1/6] query for datarows --- labelbox/pagination.py | 1 + labelbox/schema/model_run.py | 27 +++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/labelbox/pagination.py b/labelbox/pagination.py index 15677f793..4ae68b452 100644 --- a/labelbox/pagination.py +++ b/labelbox/pagination.py @@ -95,6 +95,7 @@ def __init__(self, client: "Client", obj_class: Type["DbObject"], def get_page_data(self, results: Dict[str, Any]) -> List["DbObject"]: for deref in self.dereferencing: results = results[deref] + return [self.obj_class(self.client, result) for result in results] @abstractmethod diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 26c5f057d..33c78c451 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -1,7 +1,10 @@ -from labelbox.schema.annotation_import import AnnotationImport, MALPredictionImport, MEAPredictionImport +from labelbox.pagination import PaginatedCollection +from labelbox.schema.annotation_import import MEAPredictionImport from pathlib import Path +from labelbox.orm.model import Entity from typing import Dict, Iterable, Union -from labelbox.orm.model import Field, Relationship +from labelbox.orm.query import results_query_part +from labelbox.orm.model import Field from labelbox.orm.db_object import DbObject @@ -55,3 +58,23 @@ def add_predictions( else: raise ValueError( f'Invalid annotations given of type: {type(annotations)}') + + def data_rows(self): + query_str = """ + query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){ + annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first) + { + nodes + { + dataRow {%s} + }, + pageInfo{endCursor} + } + } + """ % (results_query_part(Entity.DataRow)) + + return PaginatedCollection( + self.client, query_str, {'modelRunId': self.uid}, + ['annotationGroups', 'nodes'], + lambda c, x: Entity.DataRow(c, x['dataRow']), + ['annotationGroups', 'pageInfo', 'endCursor']) From 76e1c2e89bbbf4ece7e9400f1958bf9f199feb46 Mon Sep 17 00:00:00 2001 From: Matt Sokoloff Date: Wed, 16 Jun 2021 12:32:36 -0400 Subject: [PATCH 2/6] clean up --- CHANGELOG.md | 3 +++ labelbox/schema/model_run.py | 12 ++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dadf5e62f..d49323fb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +# Next mea release +* Added `ModelRun.data_rows()` to fetch data rows for a given model run + # Version 2.5b0+mea (2021-06-11) ## Added * Added new `Model` and 'ModelRun` entities diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 33c78c451..3eeebe35b 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -1,10 +1,10 @@ +from typing import Dict, Iterable, Union +from pathlib import Path + from labelbox.pagination import PaginatedCollection from labelbox.schema.annotation_import import MEAPredictionImport -from pathlib import Path -from labelbox.orm.model import Entity -from typing import Dict, Iterable, Union from labelbox.orm.query import results_query_part -from labelbox.orm.model import Field +from labelbox.orm.model import Field, Entity from labelbox.orm.db_object import DbObject @@ -76,5 +76,5 @@ def data_rows(self): return PaginatedCollection( self.client, query_str, {'modelRunId': self.uid}, ['annotationGroups', 'nodes'], - lambda c, x: Entity.DataRow(c, x['dataRow']), - ['annotationGroups', 'pageInfo', 'endCursor']) + lambda client, response: Entity.DataRow(client, response[ + 'dataRow']), ['annotationGroups', 'pageInfo', 'endCursor']) From dbad88b173b8b7fa6ee4a85ca6bd018c2f9fe8d3 Mon Sep 17 00:00:00 2001 From: Matt Sokoloff Date: Wed, 16 Jun 2021 18:07:23 -0400 Subject: [PATCH 3/6] add annotation groups --- labelbox/client.py | 5 +++- labelbox/schema/annotation_import.py | 3 ++ labelbox/schema/model_run.py | 38 ++++++++++++++++--------- labelbox/utils.py | 42 ++++++++++++++++++++++++++++ 4 files changed, 74 insertions(+), 14 deletions(-) diff --git a/labelbox/client.py b/labelbox/client.py index 8c623eb8d..730636987 100644 --- a/labelbox/client.py +++ b/labelbox/client.py @@ -38,7 +38,8 @@ class Client: def __init__(self, api_key=None, endpoint='https://api.labelbox.com/graphql', - enable_experimental=False): + enable_experimental=False, + app_url="https://app.labelbox.com"): """ Creates and initializes a Labelbox Client. Logging is defaulted to level WARNING. To receive more verbose @@ -52,6 +53,7 @@ def __init__(self, api_key (str): API key. If None, the key is obtained from the "LABELBOX_API_KEY" environment variable. endpoint (str): URL of the Labelbox server to connect to. enable_experimental (bool): Indicates whether or not to use experimental features + app_url (str) : host url for all links to the web app Raises: labelbox.exceptions.AuthenticationError: If no `api_key` is provided as an argument or via the environment @@ -69,6 +71,7 @@ def __init__(self, logger.info("Experimental features have been enabled") logger.info("Initializing Labelbox client at '%s'", endpoint) + self.app_url = app_url # TODO: Make endpoints non-internal or support them as experimental self.endpoint = endpoint.replace('/graphql', '/_gql') diff --git a/labelbox/schema/annotation_import.py b/labelbox/schema/annotation_import.py index 5220dd93b..2df6fae1a 100644 --- a/labelbox/schema/annotation_import.py +++ b/labelbox/schema/annotation_import.py @@ -123,6 +123,9 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]: Returns: ndjson as a list of dicts. """ + if self.state == AnnotationImportState.FAILED: + raise Exception("") + response = requests.get(url) response.raise_for_status() return ndjson.loads(response.text) diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 3eeebe35b..dd2fcb94a 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -1,10 +1,11 @@ from typing import Dict, Iterable, Union from pathlib import Path +from labelbox.utils import uuid_to_cuid from labelbox.pagination import PaginatedCollection from labelbox.schema.annotation_import import MEAPredictionImport from labelbox.orm.query import results_query_part -from labelbox.orm.model import Field, Entity +from labelbox.orm.model import Field, Relationship from labelbox.orm.db_object import DbObject @@ -13,6 +14,7 @@ class ModelRun(DbObject): updated_at = Field.DateTime("updated_at") created_at = Field.DateTime("created_at") created_by_id = Field.String("created_by_id", "createdBy") + model_id = Field.String("model_id") def upsert_labels(self, label_ids): @@ -59,22 +61,32 @@ def add_predictions( raise ValueError( f'Invalid annotations given of type: {type(annotations)}') - def data_rows(self): + def annotation_groups(self): query_str = """ query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){ annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first) - { - nodes - { - dataRow {%s} - }, - pageInfo{endCursor} - } + {nodes{%s},pageInfo{endCursor}} } - """ % (results_query_part(Entity.DataRow)) - + """ % (results_query_part(AnnotationGroup)) return PaginatedCollection( self.client, query_str, {'modelRunId': self.uid}, ['annotationGroups', 'nodes'], - lambda client, response: Entity.DataRow(client, response[ - 'dataRow']), ['annotationGroups', 'pageInfo', 'endCursor']) + lambda client, res: AnnotationGroup(client, self.model_id, res), + ['annotationGroups', 'pageInfo', 'endCursor']) + + +class AnnotationGroup(DbObject): + label_id = Field.String("label_id") + model_run_id = Field.String("model_run_id") + data_row = Relationship.ToOne("DataRow", False, cache=True) + + def __init__(self, client, model_id, field_values): + field_values['labelId'] = uuid_to_cuid(field_values['labelId']) + super().__init__(client, field_values) + self.model_id = model_id + + @property + def url(self): + app_url = self.client.app_url + endpoint = f"{app_url}/models/{self.model_id}/{self.model_run_id}/AllDatarowsSlice/{self.uid}?view=carousel" + return endpoint diff --git a/labelbox/utils.py b/labelbox/utils.py index b4c70b23e..a7e63c688 100644 --- a/labelbox/utils.py +++ b/labelbox/utils.py @@ -1,4 +1,11 @@ import re +import re +import uuid +import base36 + +_CUID_REGEX = r"^c[0-9a-z]{24}$" +MAX_SUPPORTED_CUID = "cy3mdbdhy3uqaqwzejcdh6akf" +MAX_SUPPORTED_UUID = "ffffffff-ffff-0fff-ffff-ffffffffffff" def _convert(s, sep, title): @@ -23,3 +30,38 @@ def title_case(s): def snake_case(s): """ Converts a string in [snake|camel|title]case to snake_case. """ return _convert(s, "_", lambda i: False) + + +def cuid_to_uuid(cuid: str) -> uuid.UUID: + if not re.match(_CUID_REGEX, cuid) or cuid > MAX_SUPPORTED_CUID: + raise ValueError("Invalid CUID: " + cuid) + + cleaned = cuid[1:] + + intermediate = 0 + for c in cleaned: + intermediate = intermediate * 36 + int(c, 36) + intermediate_str = f"{intermediate:x}" # int->str in hexadecimal + + padded = (32 - len(intermediate_str)) * '0' + intermediate_str + + return uuid.UUID("-".join((padded[1:9], padded[9:13], "0" + padded[13:16], + padded[16:20], padded[20:32]))) + + +def uuid_to_cuid(uuid: uuid.UUID) -> str: + cleaned = str(uuid).replace("-", "") + + if cleaned[12] != "0": + raise ValueError("Invalid UUID with non-zero version hex digit") + + cleaned = cleaned[0:12] + cleaned[13:] + + intermediate = 0 + for c in cleaned: + intermediate = intermediate * 16 + int(c, 16) + intermediate_str = base36.dumps(intermediate) # int->str in base36 + + padded = (24 - len(intermediate_str)) * '0' + intermediate_str + + return "c" + padded From 06ac809b17e36ac839263cddc4ca2703ef3a3648 Mon Sep 17 00:00:00 2001 From: Matt Sokoloff Date: Wed, 16 Jun 2021 18:08:48 -0400 Subject: [PATCH 4/6] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d49323fb3..1e5dbb113 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ # Changelog # Next mea release -* Added `ModelRun.data_rows()` to fetch data rows for a given model run +* Added `ModelRun.annotation_groups()` to fetch data rows and label information for a model run # Version 2.5b0+mea (2021-06-11) ## Added From aadb162c45f7f86801086fbd4200c16a489ed0dc Mon Sep 17 00:00:00 2001 From: Matt Sokoloff Date: Wed, 16 Jun 2021 18:13:26 -0400 Subject: [PATCH 5/6] raise ValueError when user attempts to get status for failed import --- labelbox/schema/annotation_import.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/labelbox/schema/annotation_import.py b/labelbox/schema/annotation_import.py index 2df6fae1a..2b3801bcf 100644 --- a/labelbox/schema/annotation_import.py +++ b/labelbox/schema/annotation_import.py @@ -124,7 +124,7 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]: ndjson as a list of dicts. """ if self.state == AnnotationImportState.FAILED: - raise Exception("") + raise ValueError("Import failed.") response = requests.get(url) response.raise_for_status() From 8d111fa134443c32a2f2bc88dfb164e80b1d3f0e Mon Sep 17 00:00:00 2001 From: Matt Sokoloff Date: Wed, 16 Jun 2021 18:43:30 -0400 Subject: [PATCH 6/6] recommended changes --- labelbox/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/labelbox/utils.py b/labelbox/utils.py index a7e63c688..b0fbdec64 100644 --- a/labelbox/utils.py +++ b/labelbox/utils.py @@ -1,5 +1,4 @@ import re -import re import uuid import base36