Skip to content

Commit dbad88b

Browse files
author
Matt Sokoloff
committed
add annotation groups
1 parent 76e1c2e commit dbad88b

File tree

4 files changed

+74
-14
lines changed

4 files changed

+74
-14
lines changed

labelbox/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ class Client:
3838
def __init__(self,
3939
api_key=None,
4040
endpoint='https://api.labelbox.com/graphql',
41-
enable_experimental=False):
41+
enable_experimental=False,
42+
app_url="https://app.labelbox.com"):
4243
""" Creates and initializes a Labelbox Client.
4344
4445
Logging is defaulted to level WARNING. To receive more verbose
@@ -52,6 +53,7 @@ def __init__(self,
5253
api_key (str): API key. If None, the key is obtained from the "LABELBOX_API_KEY" environment variable.
5354
endpoint (str): URL of the Labelbox server to connect to.
5455
enable_experimental (bool): Indicates whether or not to use experimental features
56+
app_url (str) : host url for all links to the web app
5557
Raises:
5658
labelbox.exceptions.AuthenticationError: If no `api_key`
5759
is provided as an argument or via the environment
@@ -69,6 +71,7 @@ def __init__(self,
6971
logger.info("Experimental features have been enabled")
7072

7173
logger.info("Initializing Labelbox client at '%s'", endpoint)
74+
self.app_url = app_url
7275

7376
# TODO: Make endpoints non-internal or support them as experimental
7477
self.endpoint = endpoint.replace('/graphql', '/_gql')

labelbox/schema/annotation_import.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
123123
Returns:
124124
ndjson as a list of dicts.
125125
"""
126+
if self.state == AnnotationImportState.FAILED:
127+
raise Exception("")
128+
126129
response = requests.get(url)
127130
response.raise_for_status()
128131
return ndjson.loads(response.text)

labelbox/schema/model_run.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from typing import Dict, Iterable, Union
22
from pathlib import Path
33

4+
from labelbox.utils import uuid_to_cuid
45
from labelbox.pagination import PaginatedCollection
56
from labelbox.schema.annotation_import import MEAPredictionImport
67
from labelbox.orm.query import results_query_part
7-
from labelbox.orm.model import Field, Entity
8+
from labelbox.orm.model import Field, Relationship
89
from labelbox.orm.db_object import DbObject
910

1011

@@ -13,6 +14,7 @@ class ModelRun(DbObject):
1314
updated_at = Field.DateTime("updated_at")
1415
created_at = Field.DateTime("created_at")
1516
created_by_id = Field.String("created_by_id", "createdBy")
17+
model_id = Field.String("model_id")
1618

1719
def upsert_labels(self, label_ids):
1820

@@ -59,22 +61,32 @@ def add_predictions(
5961
raise ValueError(
6062
f'Invalid annotations given of type: {type(annotations)}')
6163

62-
def data_rows(self):
64+
def annotation_groups(self):
6365
query_str = """
6466
query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){
6567
annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first)
66-
{
67-
nodes
68-
{
69-
dataRow {%s}
70-
},
71-
pageInfo{endCursor}
72-
}
68+
{nodes{%s},pageInfo{endCursor}}
7369
}
74-
""" % (results_query_part(Entity.DataRow))
75-
70+
""" % (results_query_part(AnnotationGroup))
7671
return PaginatedCollection(
7772
self.client, query_str, {'modelRunId': self.uid},
7873
['annotationGroups', 'nodes'],
79-
lambda client, response: Entity.DataRow(client, response[
80-
'dataRow']), ['annotationGroups', 'pageInfo', 'endCursor'])
74+
lambda client, res: AnnotationGroup(client, self.model_id, res),
75+
['annotationGroups', 'pageInfo', 'endCursor'])
76+
77+
78+
class AnnotationGroup(DbObject):
79+
label_id = Field.String("label_id")
80+
model_run_id = Field.String("model_run_id")
81+
data_row = Relationship.ToOne("DataRow", False, cache=True)
82+
83+
def __init__(self, client, model_id, field_values):
84+
field_values['labelId'] = uuid_to_cuid(field_values['labelId'])
85+
super().__init__(client, field_values)
86+
self.model_id = model_id
87+
88+
@property
89+
def url(self):
90+
app_url = self.client.app_url
91+
endpoint = f"{app_url}/models/{self.model_id}/{self.model_run_id}/AllDatarowsSlice/{self.uid}?view=carousel"
92+
return endpoint

labelbox/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
import re
2+
import re
3+
import uuid
4+
import base36
5+
6+
_CUID_REGEX = r"^c[0-9a-z]{24}$"
7+
MAX_SUPPORTED_CUID = "cy3mdbdhy3uqaqwzejcdh6akf"
8+
MAX_SUPPORTED_UUID = "ffffffff-ffff-0fff-ffff-ffffffffffff"
29

310

411
def _convert(s, sep, title):
@@ -23,3 +30,38 @@ def title_case(s):
2330
def snake_case(s):
2431
""" Converts a string in [snake|camel|title]case to snake_case. """
2532
return _convert(s, "_", lambda i: False)
33+
34+
35+
def cuid_to_uuid(cuid: str) -> uuid.UUID:
36+
if not re.match(_CUID_REGEX, cuid) or cuid > MAX_SUPPORTED_CUID:
37+
raise ValueError("Invalid CUID: " + cuid)
38+
39+
cleaned = cuid[1:]
40+
41+
intermediate = 0
42+
for c in cleaned:
43+
intermediate = intermediate * 36 + int(c, 36)
44+
intermediate_str = f"{intermediate:x}" # int->str in hexadecimal
45+
46+
padded = (32 - len(intermediate_str)) * '0' + intermediate_str
47+
48+
return uuid.UUID("-".join((padded[1:9], padded[9:13], "0" + padded[13:16],
49+
padded[16:20], padded[20:32])))
50+
51+
52+
def uuid_to_cuid(uuid: uuid.UUID) -> str:
53+
cleaned = str(uuid).replace("-", "")
54+
55+
if cleaned[12] != "0":
56+
raise ValueError("Invalid UUID with non-zero version hex digit")
57+
58+
cleaned = cleaned[0:12] + cleaned[13:]
59+
60+
intermediate = 0
61+
for c in cleaned:
62+
intermediate = intermediate * 16 + int(c, 16)
63+
intermediate_str = base36.dumps(intermediate) # int->str in base36
64+
65+
padded = (24 - len(intermediate_str)) * '0' + intermediate_str
66+
67+
return "c" + padded

0 commit comments

Comments
 (0)