diff --git a/CHANGELOG.md b/CHANGELOG.md index dadf5e62f..1e5dbb113 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +# Next mea release +* Added `ModelRun.annotation_groups()` to fetch data rows and label information for a model run + # Version 2.5b0+mea (2021-06-11) ## Added * Added new `Model` and 'ModelRun` entities diff --git a/labelbox/client.py b/labelbox/client.py index 8c623eb8d..730636987 100644 --- a/labelbox/client.py +++ b/labelbox/client.py @@ -38,7 +38,8 @@ class Client: def __init__(self, api_key=None, endpoint='https://api.labelbox.com/graphql', - enable_experimental=False): + enable_experimental=False, + app_url="https://app.labelbox.com"): """ Creates and initializes a Labelbox Client. Logging is defaulted to level WARNING. To receive more verbose @@ -52,6 +53,7 @@ def __init__(self, api_key (str): API key. If None, the key is obtained from the "LABELBOX_API_KEY" environment variable. endpoint (str): URL of the Labelbox server to connect to. enable_experimental (bool): Indicates whether or not to use experimental features + app_url (str) : host url for all links to the web app Raises: labelbox.exceptions.AuthenticationError: If no `api_key` is provided as an argument or via the environment @@ -69,6 +71,7 @@ def __init__(self, logger.info("Experimental features have been enabled") logger.info("Initializing Labelbox client at '%s'", endpoint) + self.app_url = app_url # TODO: Make endpoints non-internal or support them as experimental self.endpoint = endpoint.replace('/graphql', '/_gql') diff --git a/labelbox/pagination.py b/labelbox/pagination.py index 15677f793..4ae68b452 100644 --- a/labelbox/pagination.py +++ b/labelbox/pagination.py @@ -95,6 +95,7 @@ def __init__(self, client: "Client", obj_class: Type["DbObject"], def get_page_data(self, results: Dict[str, Any]) -> List["DbObject"]: for deref in self.dereferencing: results = results[deref] + return [self.obj_class(self.client, result) for result in results] @abstractmethod diff --git a/labelbox/schema/annotation_import.py b/labelbox/schema/annotation_import.py index 5220dd93b..2b3801bcf 100644 --- a/labelbox/schema/annotation_import.py +++ b/labelbox/schema/annotation_import.py @@ -123,6 +123,9 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]: Returns: ndjson as a list of dicts. """ + if self.state == AnnotationImportState.FAILED: + raise ValueError("Import failed.") + response = requests.get(url) response.raise_for_status() return ndjson.loads(response.text) diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 26c5f057d..dd2fcb94a 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -1,6 +1,10 @@ -from labelbox.schema.annotation_import import AnnotationImport, MALPredictionImport, MEAPredictionImport -from pathlib import Path from typing import Dict, Iterable, Union +from pathlib import Path + +from labelbox.utils import uuid_to_cuid +from labelbox.pagination import PaginatedCollection +from labelbox.schema.annotation_import import MEAPredictionImport +from labelbox.orm.query import results_query_part from labelbox.orm.model import Field, Relationship from labelbox.orm.db_object import DbObject @@ -10,6 +14,7 @@ class ModelRun(DbObject): updated_at = Field.DateTime("updated_at") created_at = Field.DateTime("created_at") created_by_id = Field.String("created_by_id", "createdBy") + model_id = Field.String("model_id") def upsert_labels(self, label_ids): @@ -55,3 +60,33 @@ def add_predictions( else: raise ValueError( f'Invalid annotations given of type: {type(annotations)}') + + def annotation_groups(self): + query_str = """ + query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){ + annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first) + {nodes{%s},pageInfo{endCursor}} + } + """ % (results_query_part(AnnotationGroup)) + return PaginatedCollection( + self.client, query_str, {'modelRunId': self.uid}, + ['annotationGroups', 'nodes'], + lambda client, res: AnnotationGroup(client, self.model_id, res), + ['annotationGroups', 'pageInfo', 'endCursor']) + + +class AnnotationGroup(DbObject): + label_id = Field.String("label_id") + model_run_id = Field.String("model_run_id") + data_row = Relationship.ToOne("DataRow", False, cache=True) + + def __init__(self, client, model_id, field_values): + field_values['labelId'] = uuid_to_cuid(field_values['labelId']) + super().__init__(client, field_values) + self.model_id = model_id + + @property + def url(self): + app_url = self.client.app_url + endpoint = f"{app_url}/models/{self.model_id}/{self.model_run_id}/AllDatarowsSlice/{self.uid}?view=carousel" + return endpoint diff --git a/labelbox/utils.py b/labelbox/utils.py index b4c70b23e..b0fbdec64 100644 --- a/labelbox/utils.py +++ b/labelbox/utils.py @@ -1,4 +1,10 @@ import re +import uuid +import base36 + +_CUID_REGEX = r"^c[0-9a-z]{24}$" +MAX_SUPPORTED_CUID = "cy3mdbdhy3uqaqwzejcdh6akf" +MAX_SUPPORTED_UUID = "ffffffff-ffff-0fff-ffff-ffffffffffff" def _convert(s, sep, title): @@ -23,3 +29,38 @@ def title_case(s): def snake_case(s): """ Converts a string in [snake|camel|title]case to snake_case. """ return _convert(s, "_", lambda i: False) + + +def cuid_to_uuid(cuid: str) -> uuid.UUID: + if not re.match(_CUID_REGEX, cuid) or cuid > MAX_SUPPORTED_CUID: + raise ValueError("Invalid CUID: " + cuid) + + cleaned = cuid[1:] + + intermediate = 0 + for c in cleaned: + intermediate = intermediate * 36 + int(c, 36) + intermediate_str = f"{intermediate:x}" # int->str in hexadecimal + + padded = (32 - len(intermediate_str)) * '0' + intermediate_str + + return uuid.UUID("-".join((padded[1:9], padded[9:13], "0" + padded[13:16], + padded[16:20], padded[20:32]))) + + +def uuid_to_cuid(uuid: uuid.UUID) -> str: + cleaned = str(uuid).replace("-", "") + + if cleaned[12] != "0": + raise ValueError("Invalid UUID with non-zero version hex digit") + + cleaned = cleaned[0:12] + cleaned[13:] + + intermediate = 0 + for c in cleaned: + intermediate = intermediate * 16 + int(c, 16) + intermediate_str = base36.dumps(intermediate) # int->str in base36 + + padded = (24 - len(intermediate_str)) * '0' + intermediate_str + + return "c" + padded