Skip to content

Commit 44c257c

Browse files
authored
Merge pull request #173 from Labelbox/ms/mea-model-run-datarow
mea query for datarows
2 parents 470949e + 8d111fa commit 44c257c

File tree

6 files changed

+89
-3
lines changed

6 files changed

+89
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Changelog
22

3+
# Next mea release
4+
* Added `ModelRun.annotation_groups()` to fetch data rows and label information for a model run
5+
36
# Version 2.5b0+mea (2021-06-11)
47
## Added
58
* Added new `Model` and 'ModelRun` entities

labelbox/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ class Client:
3838
def __init__(self,
3939
api_key=None,
4040
endpoint='https://api.labelbox.com/graphql',
41-
enable_experimental=False):
41+
enable_experimental=False,
42+
app_url="https://app.labelbox.com"):
4243
""" Creates and initializes a Labelbox Client.
4344
4445
Logging is defaulted to level WARNING. To receive more verbose
@@ -52,6 +53,7 @@ def __init__(self,
5253
api_key (str): API key. If None, the key is obtained from the "LABELBOX_API_KEY" environment variable.
5354
endpoint (str): URL of the Labelbox server to connect to.
5455
enable_experimental (bool): Indicates whether or not to use experimental features
56+
app_url (str) : host url for all links to the web app
5557
Raises:
5658
labelbox.exceptions.AuthenticationError: If no `api_key`
5759
is provided as an argument or via the environment
@@ -69,6 +71,7 @@ def __init__(self,
6971
logger.info("Experimental features have been enabled")
7072

7173
logger.info("Initializing Labelbox client at '%s'", endpoint)
74+
self.app_url = app_url
7275

7376
# TODO: Make endpoints non-internal or support them as experimental
7477
self.endpoint = endpoint.replace('/graphql', '/_gql')

labelbox/pagination.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(self, client: "Client", obj_class: Type["DbObject"],
9595
def get_page_data(self, results: Dict[str, Any]) -> List["DbObject"]:
9696
for deref in self.dereferencing:
9797
results = results[deref]
98+
9899
return [self.obj_class(self.client, result) for result in results]
99100

100101
@abstractmethod

labelbox/schema/annotation_import.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
123123
Returns:
124124
ndjson as a list of dicts.
125125
"""
126+
if self.state == AnnotationImportState.FAILED:
127+
raise ValueError("Import failed.")
128+
126129
response = requests.get(url)
127130
response.raise_for_status()
128131
return ndjson.loads(response.text)

labelbox/schema/model_run.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
from labelbox.schema.annotation_import import AnnotationImport, MALPredictionImport, MEAPredictionImport
2-
from pathlib import Path
31
from typing import Dict, Iterable, Union
2+
from pathlib import Path
3+
4+
from labelbox.utils import uuid_to_cuid
5+
from labelbox.pagination import PaginatedCollection
6+
from labelbox.schema.annotation_import import MEAPredictionImport
7+
from labelbox.orm.query import results_query_part
48
from labelbox.orm.model import Field, Relationship
59
from labelbox.orm.db_object import DbObject
610

@@ -10,6 +14,7 @@ class ModelRun(DbObject):
1014
updated_at = Field.DateTime("updated_at")
1115
created_at = Field.DateTime("created_at")
1216
created_by_id = Field.String("created_by_id", "createdBy")
17+
model_id = Field.String("model_id")
1318

1419
def upsert_labels(self, label_ids):
1520

@@ -55,3 +60,33 @@ def add_predictions(
5560
else:
5661
raise ValueError(
5762
f'Invalid annotations given of type: {type(annotations)}')
63+
64+
def annotation_groups(self):
65+
query_str = """
66+
query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){
67+
annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first)
68+
{nodes{%s},pageInfo{endCursor}}
69+
}
70+
""" % (results_query_part(AnnotationGroup))
71+
return PaginatedCollection(
72+
self.client, query_str, {'modelRunId': self.uid},
73+
['annotationGroups', 'nodes'],
74+
lambda client, res: AnnotationGroup(client, self.model_id, res),
75+
['annotationGroups', 'pageInfo', 'endCursor'])
76+
77+
78+
class AnnotationGroup(DbObject):
79+
label_id = Field.String("label_id")
80+
model_run_id = Field.String("model_run_id")
81+
data_row = Relationship.ToOne("DataRow", False, cache=True)
82+
83+
def __init__(self, client, model_id, field_values):
84+
field_values['labelId'] = uuid_to_cuid(field_values['labelId'])
85+
super().__init__(client, field_values)
86+
self.model_id = model_id
87+
88+
@property
89+
def url(self):
90+
app_url = self.client.app_url
91+
endpoint = f"{app_url}/models/{self.model_id}/{self.model_run_id}/AllDatarowsSlice/{self.uid}?view=carousel"
92+
return endpoint

labelbox/utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
import re
2+
import uuid
3+
import base36
4+
5+
_CUID_REGEX = r"^c[0-9a-z]{24}$"
6+
MAX_SUPPORTED_CUID = "cy3mdbdhy3uqaqwzejcdh6akf"
7+
MAX_SUPPORTED_UUID = "ffffffff-ffff-0fff-ffff-ffffffffffff"
28

39

410
def _convert(s, sep, title):
@@ -23,3 +29,38 @@ def title_case(s):
2329
def snake_case(s):
2430
""" Converts a string in [snake|camel|title]case to snake_case. """
2531
return _convert(s, "_", lambda i: False)
32+
33+
34+
def cuid_to_uuid(cuid: str) -> uuid.UUID:
35+
if not re.match(_CUID_REGEX, cuid) or cuid > MAX_SUPPORTED_CUID:
36+
raise ValueError("Invalid CUID: " + cuid)
37+
38+
cleaned = cuid[1:]
39+
40+
intermediate = 0
41+
for c in cleaned:
42+
intermediate = intermediate * 36 + int(c, 36)
43+
intermediate_str = f"{intermediate:x}" # int->str in hexadecimal
44+
45+
padded = (32 - len(intermediate_str)) * '0' + intermediate_str
46+
47+
return uuid.UUID("-".join((padded[1:9], padded[9:13], "0" + padded[13:16],
48+
padded[16:20], padded[20:32])))
49+
50+
51+
def uuid_to_cuid(uuid: uuid.UUID) -> str:
52+
cleaned = str(uuid).replace("-", "")
53+
54+
if cleaned[12] != "0":
55+
raise ValueError("Invalid UUID with non-zero version hex digit")
56+
57+
cleaned = cleaned[0:12] + cleaned[13:]
58+
59+
intermediate = 0
60+
for c in cleaned:
61+
intermediate = intermediate * 16 + int(c, 16)
62+
intermediate_str = base36.dumps(intermediate) # int->str in base36
63+
64+
padded = (24 - len(intermediate_str)) * '0' + intermediate_str
65+
66+
return "c" + padded

0 commit comments

Comments
 (0)