Skip to content

Commit c242a1e

Browse files
authored
Merge pull request #163 from Labelbox/ms/annotation-import
annotation import
2 parents 1d9924b + b778959 commit c242a1e

File tree

13 files changed

+1237
-32
lines changed

13 files changed

+1237
-32
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
# Next Release
44
## Added
5-
* Added new `Model` schema
5+
* Added new `Model` and 'ModelRun` entities
6+
* Update client to support creating and querying for `Model`s
7+
* Implement new prediction import pipeline to support both MAL and MEA
8+
* Added notebook to demonstrate how to use MEA
69

710
# Version 2.5.6 (2021-05-19)
811
## Fix

examples/model_assisted_labeling/image_mea.ipynb

Lines changed: 693 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from labelbox import Client
2+
3+
from typing import Dict, Any, Tuple
4+
from skimage import measure
5+
from io import BytesIO
6+
from PIL import Image
7+
import numpy as np
8+
import uuid
9+
10+
11+
def create_boxes_ndjson(datarow_id: str, schema_id: str, top: float, left: float,
12+
bottom: float, right: float) -> Dict[str, Any]:
13+
"""
14+
* https://docs.labelbox.com/data-model/en/index-en#bounding-box
15+
16+
Args:
17+
datarow_id (str): id of the data_row (in this case image) to add this annotation to
18+
schema_id (str): id of the bbox tool in the current ontology
19+
top, left, bottom, right (int): pixel coordinates of the bbox
20+
Returns:
21+
ndjson representation of a bounding box
22+
"""
23+
return {
24+
"uuid": str(uuid.uuid4()),
25+
"schemaId": schema_id,
26+
"dataRow": {
27+
"id": datarow_id
28+
},
29+
"bbox": {
30+
"top": int(top),
31+
"left": int(left),
32+
"height": int(bottom - top),
33+
"width": int(right - left)
34+
}
35+
}
36+
37+
38+
def create_polygon_ndjson(datarow_id: str, schema_id: str,
39+
segmentation_mask: np.ndarray) -> Dict[str, Any]:
40+
"""
41+
* https://docs.labelbox.com/data-model/en/index-en#polygon
42+
43+
Args:
44+
datarow_id (str): id of the data_row (in this case image) to add this annotation to
45+
schema_id (str): id of the bbox tool in the current ontology
46+
segmentation_mask (np.ndarray): Segmentation mask of size (image_h, image_w)
47+
- Seg mask is turned into a polygon since polygons aren't directly inferred.
48+
Returns:
49+
ndjson representation of a polygon
50+
"""
51+
contours = measure.find_contours(segmentation_mask, 0.5)
52+
#Note that complex polygons could break.
53+
pts = contours[0].astype(np.int32)
54+
pts = np.roll(pts, 1, axis=-1)
55+
pts = [{'x': int(x), 'y': int(y)} for x, y in pts]
56+
return {
57+
"uuid": str(uuid.uuid4()),
58+
"schemaId": schema_id,
59+
"dataRow": {
60+
"id": datarow_id
61+
},
62+
"polygon": pts
63+
}
64+
65+
66+
def create_mask_ndjson(client: Client, datarow_id: str, schema_id: str,
67+
segmentation_mask: np.ndarray, color: Tuple[int, int,
68+
int]) -> Dict[str, Any]:
69+
"""
70+
Creates a mask for each object in the image
71+
* https://docs.labelbox.com/data-model/en/index-en#segmentation-mask
72+
73+
Args:
74+
client (labelbox.Client): labelbox client used for uploading seg mask to google cloud storage
75+
datarow_id (str): id of the data_row (in this case image) to add this annotation to
76+
schema_id (str): id of the segmentation tool in the current ontology
77+
segmentation_mask is a segmentation mask of size (image_h, image_w)
78+
color ( Tuple[int,int,int]): rgb color to convert binary mask into 3D colorized mask
79+
Return:
80+
ndjson representation of a segmentation mask
81+
"""
82+
83+
colorize = np.concatenate(([segmentation_mask[..., np.newaxis] * c for c in color]),
84+
axis=2)
85+
img_bytes = BytesIO()
86+
Image.fromarray(colorize).save(img_bytes, format="PNG")
87+
#* Use your own signed urls so that you can resign the data
88+
#* This is just to make the demo work
89+
url = client.upload_data(content=img_bytes.getvalue(), sign=True)
90+
return {
91+
"uuid": str(uuid.uuid4()),
92+
"schemaId": schema_id,
93+
"dataRow": {
94+
"id": datarow_id
95+
},
96+
"mask": {
97+
"instanceURI": url,
98+
"colorRGB": color
99+
}
100+
}
101+
102+
103+
def create_point_ndjson(datarow_id: str, schema_id: str, top: float, left: float,
104+
bottom: float, right: float) -> Dict[str, Any]:
105+
"""
106+
* https://docs.labelbox.com/data-model/en/index-en#point
107+
108+
Args:
109+
datarow_id (str): id of the data_row (in this case image) to add this annotation to
110+
schema_id (str): id of the point tool in the current ontology
111+
t, l, b, r (int): top, left, bottom, right pixel coordinates of the bbox
112+
- The model doesn't directly predict points, so we grab the centroid of the predicted bounding box
113+
Returns:
114+
ndjson representation of a polygon
115+
"""
116+
return {
117+
"uuid": str(uuid.uuid4()),
118+
"schemaId": schema_id,
119+
"dataRow": {
120+
"id": datarow_id
121+
},
122+
"point": {
123+
"x": int((left + right) / 2.),
124+
"y": int((top + bottom) / 2.),
125+
}
126+
}

labelbox/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
__version__ = "2.5.6"
33

44
from labelbox.client import Client
5+
from labelbox.schema.model import Model
56
from labelbox.schema.bulk_import_request import BulkImportRequest
7+
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport
68
from labelbox.schema.project import Project
79
from labelbox.schema.dataset import Dataset
810
from labelbox.schema.data_row import DataRow
@@ -19,4 +21,3 @@
1921
from labelbox.schema.role import Role, ProjectRole
2022
from labelbox.schema.invite import Invite, InviteLimit
2123
from labelbox.schema.model_run import ModelRun
22-
from labelbox.schema.model import Model

labelbox/client.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,13 @@ def __init__(self,
7979

8080
@retry.Retry(predicate=retry.if_exception_type(
8181
labelbox.exceptions.InternalServerError))
82-
def execute(self, query, params=None, timeout=30.0, experimental=False):
82+
def execute(self,
83+
query=None,
84+
params=None,
85+
data=None,
86+
files=None,
87+
timeout=30.0,
88+
experimental=False):
8389
""" Sends a request to the server for the execution of the
8490
given query.
8591
@@ -89,6 +95,8 @@ def execute(self, query, params=None, timeout=30.0, experimental=False):
8995
Args:
9096
query (str): The query to execute.
9197
params (dict): Query parameters referenced within the query.
98+
data (str): json string containing the query to execute
99+
files (dict): file arguments for request
92100
timeout (float): Max allowed time for query execution,
93101
in seconds.
94102
Returns:
@@ -107,8 +115,9 @@ def execute(self, query, params=None, timeout=30.0, experimental=False):
107115
most likely due to connection issues.
108116
labelbox.exceptions.LabelboxError: If an unknown error of any
109117
kind occurred.
118+
ValueError: If query and data are both None.
110119
"""
111-
logger.debug("Query: %s, params: %r", query, params)
120+
logger.debug("Query: %s, params: %r, data %r", query, params, data)
112121

113122
# Convert datetimes to UTC strings.
114123
def convert_value(value):
@@ -117,19 +126,35 @@ def convert_value(value):
117126
value = value.strftime("%Y-%m-%dT%H:%M:%SZ")
118127
return value
119128

120-
if params is not None:
121-
params = {
122-
key: convert_value(value) for key, value in params.items()
123-
}
124-
125-
data = json.dumps({'query': query, 'variables': params}).encode('utf-8')
126-
129+
if query is not None:
130+
if params is not None:
131+
params = {
132+
key: convert_value(value) for key, value in params.items()
133+
}
134+
data = json.dumps({
135+
'query': query,
136+
'variables': params
137+
}).encode('utf-8')
138+
elif data is None:
139+
raise ValueError("query and data cannot both be none")
127140
try:
128-
response = requests.post(self.endpoint.replace('/graphql', '/_gql')
129-
if experimental else self.endpoint,
130-
data=data,
131-
headers=self.headers,
132-
timeout=timeout)
141+
request = {
142+
'url':
143+
self.endpoint.replace('/graphql', '/_gql')
144+
if experimental else self.endpoint,
145+
'data':
146+
data,
147+
'headers':
148+
self.headers,
149+
'timeout':
150+
timeout
151+
}
152+
if files:
153+
request.update({'files': files})
154+
request['headers'] = {
155+
'Authorization': self.headers['Authorization']
156+
}
157+
response = requests.post(**request)
133158
logger.debug("Response: %s", response.text)
134159
except requests.exceptions.Timeout as e:
135160
raise labelbox.exceptions.TimeoutError(str(e))
@@ -548,4 +573,14 @@ def create_model(self, name, ontology_id):
548573
InvalidAttributeError: If the Model type does not contain
549574
any of the attribute names given in kwargs.
550575
"""
551-
return self._create(Model, {"name": name, "ontology_id": ontology_id})
576+
query_str = """mutation createModelPyApi($name: String!, $ontologyId: ID!){
577+
createModel(data: {name : $name, ontologyId : $ontologyId}){
578+
%s
579+
}
580+
}""" % query.results_query_part(Model)
581+
582+
result = self.execute(query_str, {
583+
"name": name,
584+
"ontologyId": ontology_id
585+
})
586+
return Model(self, result['createModel'])

labelbox/orm/db_object.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def __init__(self, client, field_values):
4848
if relationship.cache and value is None:
4949
raise KeyError(
5050
f"Expected field values for {relationship.name}")
51-
5251
setattr(self, relationship.name,
5352
RelationshipManager(self, relationship, value))
5453

labelbox/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import labelbox.schema.asset_metadata
22
import labelbox.schema.bulk_import_request
3+
import labelbox.schema.annotation_import
34
import labelbox.schema.benchmark
45
import labelbox.schema.data_row
56
import labelbox.schema.dataset

0 commit comments

Comments
 (0)