Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

# Next mea release
* Added `ModelRun.annotation_groups()` to fetch data rows and label information for a model run

# Version 2.5b0+mea (2021-06-11)
## Added
* Added new `Model` and 'ModelRun` entities
Expand Down
5 changes: 4 additions & 1 deletion labelbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class Client:
def __init__(self,
api_key=None,
endpoint='https://api.labelbox.com/graphql',
enable_experimental=False):
enable_experimental=False,
app_url="https://app.labelbox.com"):
""" Creates and initializes a Labelbox Client.

Logging is defaulted to level WARNING. To receive more verbose
Expand All @@ -52,6 +53,7 @@ def __init__(self,
api_key (str): API key. If None, the key is obtained from the "LABELBOX_API_KEY" environment variable.
endpoint (str): URL of the Labelbox server to connect to.
enable_experimental (bool): Indicates whether or not to use experimental features
app_url (str) : host url for all links to the web app
Raises:
labelbox.exceptions.AuthenticationError: If no `api_key`
is provided as an argument or via the environment
Expand All @@ -69,6 +71,7 @@ def __init__(self,
logger.info("Experimental features have been enabled")

logger.info("Initializing Labelbox client at '%s'", endpoint)
self.app_url = app_url

# TODO: Make endpoints non-internal or support them as experimental
self.endpoint = endpoint.replace('/graphql', '/_gql')
Expand Down
1 change: 1 addition & 0 deletions labelbox/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(self, client: "Client", obj_class: Type["DbObject"],
def get_page_data(self, results: Dict[str, Any]) -> List["DbObject"]:
for deref in self.dereferencing:
results = results[deref]

return [self.obj_class(self.client, result) for result in results]

@abstractmethod
Expand Down
3 changes: 3 additions & 0 deletions labelbox/schema/annotation_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
Returns:
ndjson as a list of dicts.
"""
if self.state == AnnotationImportState.FAILED:
raise ValueError("Import failed.")

response = requests.get(url)
response.raise_for_status()
return ndjson.loads(response.text)
Expand Down
39 changes: 37 additions & 2 deletions labelbox/schema/model_run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from labelbox.schema.annotation_import import AnnotationImport, MALPredictionImport, MEAPredictionImport
from pathlib import Path
from typing import Dict, Iterable, Union
from pathlib import Path

from labelbox.utils import uuid_to_cuid
from labelbox.pagination import PaginatedCollection
from labelbox.schema.annotation_import import MEAPredictionImport
from labelbox.orm.query import results_query_part
from labelbox.orm.model import Field, Relationship
from labelbox.orm.db_object import DbObject

Expand All @@ -10,6 +14,7 @@ class ModelRun(DbObject):
updated_at = Field.DateTime("updated_at")
created_at = Field.DateTime("created_at")
created_by_id = Field.String("created_by_id", "createdBy")
model_id = Field.String("model_id")

def upsert_labels(self, label_ids):

Expand Down Expand Up @@ -55,3 +60,33 @@ def add_predictions(
else:
raise ValueError(
f'Invalid annotations given of type: {type(annotations)}')

def annotation_groups(self):
query_str = """
query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){
annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first)
{nodes{%s},pageInfo{endCursor}}
}
""" % (results_query_part(AnnotationGroup))
return PaginatedCollection(
self.client, query_str, {'modelRunId': self.uid},
['annotationGroups', 'nodes'],
lambda client, res: AnnotationGroup(client, self.model_id, res),
['annotationGroups', 'pageInfo', 'endCursor'])


class AnnotationGroup(DbObject):
label_id = Field.String("label_id")
model_run_id = Field.String("model_run_id")
data_row = Relationship.ToOne("DataRow", False, cache=True)

def __init__(self, client, model_id, field_values):
field_values['labelId'] = uuid_to_cuid(field_values['labelId'])
super().__init__(client, field_values)
self.model_id = model_id

@property
def url(self):
app_url = self.client.app_url
endpoint = f"{app_url}/models/{self.model_id}/{self.model_run_id}/AllDatarowsSlice/{self.uid}?view=carousel"
return endpoint
41 changes: 41 additions & 0 deletions labelbox/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import re
import uuid
import base36

_CUID_REGEX = r"^c[0-9a-z]{24}$"
MAX_SUPPORTED_CUID = "cy3mdbdhy3uqaqwzejcdh6akf"
MAX_SUPPORTED_UUID = "ffffffff-ffff-0fff-ffff-ffffffffffff"


def _convert(s, sep, title):
Expand All @@ -23,3 +29,38 @@ def title_case(s):
def snake_case(s):
""" Converts a string in [snake|camel|title]case to snake_case. """
return _convert(s, "_", lambda i: False)


def cuid_to_uuid(cuid: str) -> uuid.UUID:
if not re.match(_CUID_REGEX, cuid) or cuid > MAX_SUPPORTED_CUID:
raise ValueError("Invalid CUID: " + cuid)

cleaned = cuid[1:]

intermediate = 0
for c in cleaned:
intermediate = intermediate * 36 + int(c, 36)
intermediate_str = f"{intermediate:x}" # int->str in hexadecimal

padded = (32 - len(intermediate_str)) * '0' + intermediate_str

return uuid.UUID("-".join((padded[1:9], padded[9:13], "0" + padded[13:16],
padded[16:20], padded[20:32])))


def uuid_to_cuid(uuid: uuid.UUID) -> str:
cleaned = str(uuid).replace("-", "")

if cleaned[12] != "0":
raise ValueError("Invalid UUID with non-zero version hex digit")

cleaned = cleaned[0:12] + cleaned[13:]

intermediate = 0
for c in cleaned:
intermediate = intermediate * 16 + int(c, 16)
intermediate_str = base36.dumps(intermediate) # int->str in base36

padded = (24 - len(intermediate_str)) * '0' + intermediate_str

return "c" + padded