Skip to content

Commit 95ccebc

Browse files
author
Grzegorz Szpak
committed
Added method to upload local ndjson with predictions to Labelbox' GCS
1 parent 3fe8529 commit 95ccebc

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

labelbox/schema/bulk_import_request.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import json
2+
from pathlib import Path
23
from typing import Iterable
34

5+
import ndjson
6+
47
from labelbox import Client
8+
from labelbox.exceptions import LabelboxError
59
from labelbox.orm import query
610
from labelbox.orm.db_object import DbObject
711
from labelbox.orm.model import Field
812
from labelbox.orm.model import Relationship
913
from labelbox.schema.enums import BulkImportRequestState
14+
from labelbox.schema.enums import UploadedFileType
1015

1116

1217
class BulkImportRequest(DbObject):
@@ -25,7 +30,8 @@ def create(
2530
predictions: Iterable[dict]) -> 'BulkImportRequest':
2631
data_str = '\n'.join(json.dumps(prediction) for prediction in predictions)
2732
data = data_str.encode('utf-8')
28-
input_file_url = client.upload_data(data)
33+
input_file_url = client.upload_data(
34+
data, uploaded_file_type=UploadedFileType.PREDICTIONS)
2935
query_str = """
3036
mutation CreateBulkImportRequestPyApi {
3137
createBulkImportRequest(data: {
@@ -44,3 +50,27 @@ def create(
4450
)
4551
bulk_import_request_kwargs = client.execute(query_str)["createBulkImportRequest"]
4652
return BulkImportRequest(client, bulk_import_request_kwargs)
53+
54+
@staticmethod
55+
def upload_local_predictions_file(
56+
client: Client, local_predictions_file_path: Path) -> str:
57+
"""
58+
Uploads local NDJSON file containing predictions to Labelbox' object store
59+
and returns a URL of created file.
60+
61+
Args:
62+
client (Client): The Labelbox client
63+
local_predictions_file_path (str): local NDJSON file containing predictions
64+
Returns:
65+
A URL of uploaded NDJSON file
66+
Raises:
67+
LabelboxError: if local file is not a valid NDJSON file
68+
"""
69+
with local_predictions_file_path.open("rb") as f:
70+
try:
71+
data = ndjson.load(f)
72+
return client.upload_data(
73+
data, uploaded_file_type=UploadedFileType.PREDICTIONS)
74+
except ValueError:
75+
raise LabelboxError(
76+
f"File {local_predictions_file_path} is not a valid ndjson file")

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
requests==2.22.0
2+
ndjson==0.3.1

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
long_description_content_type="text/markdown",
1616
url="https://labelbox.com",
1717
packages=setuptools.find_packages(),
18-
install_requires=["requests>=2.22.0"],
18+
install_requires=[
19+
"ndjson>=0.3.1",
20+
"requests>=2.22.0"
21+
],
1922
classifiers=[
2023
'Development Status :: 3 - Alpha',
2124
'License :: OSI Approved :: Apache Software License',

0 commit comments

Comments
 (0)