diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 6a24eb8f9..1e1484503 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -53,6 +53,11 @@ jobs:
           # TODO: create a staging environment (develop)
           # we only test against prod right now because the merges are right into
           # the main branch which is develop right now
-          LABELBOX_TEST_ENVIRON: "PROD"
+          LABELBOX_TEST_ENVIRON: "prod"
+          #
+          # randall+staging-python@labelbox.com
+          #LABELBOX_TEST_API_KEY: ${{ secrets.STAGING_LABELBOX_API_KEY }}
+          #LABELBOX_TEST_ENDPOINT: "https://staging-api.labelbox.com/graphql"
+          #LABELBOX_TEST_ENVIRON: "staging"
         run: |
-          tox -- -svv
\ No newline at end of file
+          tox -- -svv
diff --git a/labelbox/__init__.py b/labelbox/__init__.py
index bf5b3c12b..115023bd4 100644
--- a/labelbox/__init__.py
+++ b/labelbox/__init__.py
@@ -1,6 +1,7 @@
 name = "labelbox"
 
 from labelbox.client import Client
+from labelbox.schema.bulk_import_request import BulkImportRequest
 from labelbox.schema.project import Project
 from labelbox.schema.dataset import Dataset
 from labelbox.schema.data_row import DataRow
diff --git a/labelbox/schema/__init__.py b/labelbox/schema/__init__.py
index 580f40f21..eadb49ab8 100644
--- a/labelbox/schema/__init__.py
+++ b/labelbox/schema/__init__.py
@@ -1,4 +1,5 @@
 import labelbox.schema.asset_metadata
+import labelbox.schema.bulk_import_request
 import labelbox.schema.benchmark
 import labelbox.schema.data_row
 import labelbox.schema.dataset
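The package-root re-export above lets downstream code import the new class directly; a one-line sketch of what it enables:

    # enabled by the labelbox/__init__.py change above
    from labelbox import BulkImportRequest
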
diff --git a/labelbox/schema/bulk_import_request.py b/labelbox/schema/bulk_import_request.py
index 8bb861c59..ef054cb48 100644
--- a/labelbox/schema/bulk_import_request.py
+++ b/labelbox/schema/bulk_import_request.py
@@ -11,10 +11,8 @@
 import ndjson
 import requests
 
+from labelbox import utils
 import labelbox.exceptions
-from labelbox import Client
-from labelbox import Project
-from labelbox import User
 from labelbox.orm import query
 from labelbox.orm.db_object import DbObject
 from labelbox.orm.model import Field
@@ -25,18 +23,143 @@
 logger = logging.getLogger(__name__)
 
 
+def _make_file_name(project_id: str, name: str) -> str:
+    return f"{project_id}__{name}.ndjson"
+
+
+# TODO(gszpak): move it to client.py
+def _make_request_data(project_id: str, name: str, content_length: int,
+                       file_name: str) -> dict:
+    query_str = """mutation createBulkImportRequestFromFilePyApi(
+            $projectId: ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
+        createBulkImportRequest(data: {
+            projectId: $projectId,
+            name: $name,
+            filePayload: {
+                file: $file,
+                contentLength: $contentLength
+            }
+        }) {
+            %s
+        }
+    }
+    """ % query.results_query_part(BulkImportRequest)
+    variables = {
+        "projectId": project_id,
+        "name": name,
+        "file": None,
+        "contentLength": content_length
+    }
+    operations = json.dumps({"variables": variables, "query": query_str})
+
+    return {
+        "operations": operations,
+        "map": (None, json.dumps({file_name: ["variables.file"]}))
+    }
+
+
+# TODO(gszpak): move it to client.py
+def _send_create_file_command(
+        client, request_data: dict, file_name: str,
+        file_data: Tuple[str, Union[bytes, BinaryIO], str]) -> dict:
+    response = requests.post(
+        client.endpoint,
+        headers={"authorization": "Bearer %s" % client.api_key},
+        data=request_data,
+        files={file_name: file_data})
+
+    try:
+        response_json = response.json()
+    except ValueError:
+        raise labelbox.exceptions.LabelboxError(
+            "Failed to parse response as JSON: %s" % response.text)
+
+    response_data = response_json.get("data", None)
+    if response_data is None:
+        raise labelbox.exceptions.LabelboxError(
+            "Failed to upload, message: %s" % response_json.get("errors", None))
+
+    if not response_data.get("createBulkImportRequest", None):
+        raise labelbox.exceptions.LabelboxError(
+            "Failed to create BulkImportRequest, message: %s" %
+            response_json.get("errors", None) or
+            response_data.get("error", None))
+
+    return response_data
+
+
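The two module-level helpers above assemble a multipart GraphQL file upload in the usual "operations"/"map" style: the operations JSON carries the mutation with the file slot set to null, and the map tells the server which form part fills that slot. Purely as an illustration (the ids, names, and byte content are made up), the request they build looks roughly like:

    # hypothetical output of _make_request_data("cke1", "run-1", 42, "cke1__run-1.ndjson")
    request_data = {
        "operations": '{"variables": {"projectId": "cke1", "name": "run-1", '
                      '"file": null, "contentLength": 42}, "query": "mutation ..."}',
        "map": (None, '{"cke1__run-1.ndjson": ["variables.file"]}'),
    }
    # _send_create_file_command then posts it together with the file part:
    #   requests.post(endpoint, data=request_data,
    #                 files={"cke1__run-1.ndjson": file_data})
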
 class BulkImportRequest(DbObject):
-    project = Relationship.ToOne("Project")
     name = Field.String("name")
-    created_at = Field.DateTime("created_at")
-    created_by = Relationship.ToOne("User", False, "created_by")
+    state = Field.Enum(BulkImportRequestState, "state")
     input_file_url = Field.String("input_file_url")
     error_file_url = Field.String("error_file_url")
     status_file_url = Field.String("status_file_url")
-    state = Field.Enum(BulkImportRequestState, "state")
+    created_at = Field.DateTime("created_at")
+
+    project = Relationship.ToOne("Project")
+    created_by = Relationship.ToOne("User", False, "created_by")
+
+    def refresh(self) -> None:
+        """
+        Synchronizes values of all fields with the database.
+        """
+        query_str, params = query.get_single(BulkImportRequest, self.uid)
+        res = self.client.execute(query_str, params)
+        res = res[utils.camel_case(BulkImportRequest.type_name())]
+        self._set_field_values(res)
+
+    def wait_until_done(self, sleep_time_seconds: int = 30) -> None:
+        """
+        Blocks until the BulkImportRequest.state changes either to
+        `BulkImportRequestState.FINISHED` or `BulkImportRequestState.FAILED`,
+        periodically refreshing the object's state.
+
+        Args:
+            sleep_time_seconds (int): seconds to block between subsequent API calls
+        """
+        while self.state == BulkImportRequestState.RUNNING:
+            logger.info(f"Sleeping for {sleep_time_seconds} seconds...")
+            time.sleep(sleep_time_seconds)
+            self.__exponential_backoff_refresh()
+
+    @backoff.on_exception(
+        backoff.expo,
+        (labelbox.exceptions.ApiLimitError, labelbox.exceptions.TimeoutError,
+         labelbox.exceptions.NetworkError),
+        max_tries=10,
+        jitter=None)
+    def __exponential_backoff_refresh(self) -> None:
+        self.refresh()
 
     @classmethod
-    def create_from_url(cls, client: Client, project_id: str, name: str,
+    def from_name(cls, client, project_id: str,
+                  name: str) -> 'BulkImportRequest':
+        """ Fetches an existing BulkImportRequest.
+
+        Args:
+            client (Client): a Labelbox client
+            project_id (str): BulkImportRequest's project id
+            name (str): name of BulkImportRequest
+        Returns:
+            BulkImportRequest object
+
+        """
+        query_str = """query getBulkImportRequestPyApi(
+                $projectId: ID!, $name: String!) {
+            bulkImportRequest(where: {
+                projectId: $projectId,
+                name: $name
+            }) {
+                %s
+            }
+        }
+        """ % query.results_query_part(cls)
+        params = {"projectId": project_id, "name": name}
+        response = client.execute(query_str, params=params)
+        return cls(client, response['bulkImportRequest'])
+
+    @classmethod
+    def create_from_url(cls, client, project_id: str, name: str,
                         url: str) -> 'BulkImportRequest':
         """
         Creates a BulkImportRequest from a publicly accessible URL
@@ -60,14 +183,14 @@ def create_from_url(cls, client: Client, project_id: str, name: str,
                 %s
             }
         }
-        """ % cls.__build_results_query_part()
+        """ % query.results_query_part(cls)
         params = {"projectId": project_id, "name": name, "fileUrl": url}
         bulk_import_request_response = client.execute(query_str, params=params)
-        return cls.__build_bulk_import_request_from_result(
-            client, bulk_import_request_response["createBulkImportRequest"])
+        return cls(client,
+                   bulk_import_request_response["createBulkImportRequest"])
 
     @classmethod
-    def create_from_objects(cls, client: Client, project_id: str, name: str,
+    def create_from_objects(cls, client, project_id: str, name: str,
                             predictions: Iterable[dict]) -> 'BulkImportRequest':
         """
         Creates a BulkImportRequest from an iterable of dictionaries conforming to
@@ -96,18 +219,20 @@ def create_from_objects(cls, client: Client, project_id: str, name: str,
         """
         data_str = ndjson.dumps(predictions)
         data = data_str.encode('utf-8')
-        file_name = cls.__make_file_name(project_id, name)
-        request_data = cls.__make_request_data(project_id, name, len(data_str),
-                                               file_name)
+        file_name = _make_file_name(project_id, name)
+        request_data = _make_request_data(project_id, name, len(data_str),
+                                          file_name)
         file_data = (file_name, data, NDJSON_MIME_TYPE)
-        response_data = cls.__send_create_file_command(client, request_data,
-                                                       file_name, file_data)
-        return cls.__build_bulk_import_request_from_result(
-            client, response_data["createBulkImportRequest"])
+        response_data = _send_create_file_command(client,
+                                                  request_data=request_data,
+                                                  file_name=file_name,
+                                                  file_data=file_data)
+
+        return cls(client, response_data["createBulkImportRequest"])
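A minimal usage sketch for the in-memory path (the prediction dict is a placeholder, not the authoritative import schema):

    import uuid
    from labelbox import BulkImportRequest

    predictions = [{"uuid": str(uuid.uuid4())}]  # illustrative row only
    bulk_import_request = BulkImportRequest.create_from_objects(
        client, project_id=project.uid, name=str(uuid.uuid4()),
        predictions=predictions)
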
 
     @classmethod
     def create_from_local_file(cls,
-                               client: Client,
+                               client,
                                project_id: str,
                                name: str,
                                file: Path,
@@ -124,197 +249,28 @@ def create_from_local_file(cls,
                 if `file` is a valid ndjson file
         Returns:
             BulkImportRequest object
+
         """
-        file_name = cls.__make_file_name(project_id, name)
+        file_name = _make_file_name(project_id, name)
         content_length = file.stat().st_size
-        request_data = cls.__make_request_data(project_id, name, content_length,
-                                               file_name)
+        request_data = _make_request_data(project_id, name, content_length,
+                                          file_name)
+
         with file.open('rb') as f:
-            file_data: Tuple[str, Union[bytes, BinaryIO], str]
             if validate_file:
-                data = f.read()
+                reader = ndjson.reader(f)
+                # ensure that the underlying json load call is valid
+                # https://github.com/rhgrant10/ndjson/blob/ff2f03c56b21f28f7271b27da35ca4a8bf9a05d0/ndjson/api.py#L53
+                # by iterating through the file so we only store
+                # each line in memory rather than the entire file
                 try:
-                    ndjson.loads(data)
+                    for line in reader:
+                        pass
                 except ValueError:
                     raise ValueError(f"{file} is not a valid ndjson file")
-            file_data = (file.name, data, NDJSON_MIME_TYPE)
-        else:
-            file_data = (file.name, f, NDJSON_MIME_TYPE)
-        response_data = cls.__send_create_file_command(
-            client, request_data, file_name, file_data)
-        return cls.__build_bulk_import_request_from_result(
-            client, response_data["createBulkImportRequest"])
-
-    # TODO(gszpak): building query body should be handled by the client
-    @classmethod
-    def get(cls, client: Client, project_id: str,
-            name: str) -> 'BulkImportRequest':
-        """
-        Fetches existing BulkImportRequest.
-
-        Args:
-            client (Client): a Labelbox client
-            project_id (str): BulkImportRequest's project id
-            name (str): name of BulkImportRequest
-        Returns:
-            BulkImportRequest object
-        """
-        query_str = """query getBulkImportRequestPyApi(
-                $projectId: ID!, $name: String!) {
-            bulkImportRequest(where: {
-                projectId: $projectId,
-                name: $name
-            }) {
-                %s
-            }
-        }
-        """ % cls.__build_results_query_part()
-        params = {"projectId": project_id, "name": name}
-        bulk_import_request_kwargs = \
-            client.execute(query_str, params=params).get("bulkImportRequest")
-        if bulk_import_request_kwargs is None:
-            raise labelbox.exceptions.ResourceNotFoundError(
-                BulkImportRequest, {
-                    "projectId": project_id,
-                    "name": name
-                })
-        return cls.__build_bulk_import_request_from_result(
-            client, bulk_import_request_kwargs)
-
-    def refresh(self) -> None:
-        """
-        Synchronizes values of all fields with the database.
-        """
-        bulk_import_request = self.get(self.client,
-                                       self.project().uid, self.name)
-        for field in self.fields():
-            setattr(self, field.name, getattr(bulk_import_request, field.name))
-
-    def wait_until_done(self, sleep_time_seconds: int = 30) -> None:
-        """
-        Blocks until the BulkImportRequest.state changes either to
-        `BulkImportRequestState.FINISHED` or `BulkImportRequestState.FAILED`,
-        periodically refreshing object's state.
-
-        Args:
-            sleep_time_seconds (str): a time to block between subsequent API calls
-        """
-        while self.state == BulkImportRequestState.RUNNING:
-            logger.info(f"Sleeping for {sleep_time_seconds} seconds...")
-            time.sleep(sleep_time_seconds)
-            self.__exponential_backoff_refresh()
-
-    @backoff.on_exception(
-        backoff.expo,
-        (labelbox.exceptions.ApiLimitError, labelbox.exceptions.TimeoutError,
-         labelbox.exceptions.NetworkError),
-        max_tries=10,
-        jitter=None)
-    def __exponential_backoff_refresh(self) -> None:
-        self.refresh()
-
-    # TODO(gszpak): project() and created_by() methods
-    # TODO(gszpak): are hacky ways to eagerly load the relationships
-    def project(self):  # type: ignore
-        if self.__project is not None:
-            return self.__project
-        return None
-
-    def created_by(self):  # type: ignore
-        if self.__user is not None:
-            return self.__user
-        return None
-
-    @classmethod
-    def __make_file_name(cls, project_id: str, name: str) -> str:
-        return f"{project_id}__{name}.ndjson"
-
-    # TODO(gszpak): move it to client.py
-    @classmethod
-    def __make_request_data(cls, project_id: str, name: str,
-                            content_length: int, file_name: str) -> dict:
-        query_str = """mutation createBulkImportRequestFromFilePyApi(
-            $projectId: ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
-        createBulkImportRequest(data: {
-            projectId: $projectId,
-            name: $name,
-            filePayload: {
-                file: $file,
-                contentLength: $contentLength
-            }
-        }) {
-            %s
-        }
-    }
-        """ % cls.__build_results_query_part()
-        variables = {
-            "projectId": project_id,
-            "name": name,
-            "file": None,
-            "contentLength": content_length
-        }
-        operations = json.dumps({"variables": variables, "query": query_str})
-
-        return {
-            "operations": operations,
-            "map": (None, json.dumps({file_name: ["variables.file"]}))
-        }
-
-    # TODO(gszpak): move it to client.py
-    @classmethod
-    def __send_create_file_command(
-            cls, client: Client, request_data: dict, file_name: str,
-            file_data: Tuple[str, Union[bytes, BinaryIO], str]) -> dict:
-        response = requests.post(
-            client.endpoint,
-            headers={"authorization": "Bearer %s" % client.api_key},
-            data=request_data,
-            files={file_name: file_data})
-
-        try:
-            response_json = response.json()
-        except ValueError:
-            raise labelbox.exceptions.LabelboxError(
-                "Failed to parse response as JSON: %s" % response.text)
-
-        response_data = response_json.get("data", None)
-        if response_data is None:
-            raise labelbox.exceptions.LabelboxError(
-                "Failed to upload, message: %s" %
-                response_json.get("errors", None))
-
-        if not response_data.get("createBulkImportRequest", None):
-            raise labelbox.exceptions.LabelboxError(
-                "Failed to create BulkImportRequest, message: %s" %
-                response_json.get("errors", None) or
-                response_data.get("error", None))
-
-        return response_data
-
-    # TODO(gszpak): all the code below should be handled automatically by Relationship
-    @classmethod
-    def __build_results_query_part(cls) -> str:
-        return """
-            project {
-                %s
-            }
-            createdBy {
-                %s
-            }
-            %s
-        """ % (query.results_query_part(Project),
-               query.results_query_part(User),
-               query.results_query_part(BulkImportRequest))
-
-    @classmethod
-    def __build_bulk_import_request_from_result(
-            cls, client: Client, result: dict) -> 'BulkImportRequest':
-        project = result.pop("project")
-        user = result.pop("createdBy")
-        bulk_import_request = BulkImportRequest(client, result)
-        if project is not None:
-            bulk_import_request.__project = Project(  # type: ignore
-                client, project)
-        if user is not None:
-            bulk_import_request.__user = User(client, user)  # type: ignore
-        return bulk_import_request
+            else:
+                f.seek(0)
+            file_data = (file.name, f, NDJSON_MIME_TYPE)
+            response_data = _send_create_file_command(client, request_data,
+                                                      file_name, file_data)
+        return cls(client, response_data["createBulkImportRequest"])
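That completes the reworked module: construction now goes through the plain cls(client, ...) constructor, and polling is a sketch away (assumes a configured Client and an existing import named "run-1"; the project id is a placeholder):

    from labelbox import Client, BulkImportRequest

    client = Client(api_key="...")  # key elided
    bir = BulkImportRequest.from_name(client, project_id="<project-id>",
                                      name="run-1")
    bir.wait_until_done(sleep_time_seconds=30)  # re-fetches while state == RUNNING
    print(bir.state, bir.status_file_url, bir.error_file_url)
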
diff --git a/labelbox/schema/project.py b/labelbox/schema/project.py
index 5b7272924..d5bc9e5a7 100644
--- a/labelbox/schema/project.py
+++ b/labelbox/schema/project.py
@@ -2,9 +2,13 @@
 from datetime import datetime, timezone
 import json
 import logging
+from pathlib import Path
 import time
+from typing import Union, Iterable
+from urllib.parse import urlparse
 
 from labelbox import utils
+from labelbox.schema.bulk_import_request import BulkImportRequest
 from labelbox.exceptions import InvalidQueryError
 from labelbox.orm import query
 from labelbox.orm.db_object import DbObject, Updateable, Deletable
@@ -113,7 +117,7 @@ def export_labels(self, timeout_seconds=60):
         payload, and returns the URL to that payload. Will only generate a new
         URL at a max frequency of 30 min.
-        
+
         Args:
             timeout_seconds (float): Max waiting time, in seconds.
         Returns:
@@ -352,6 +356,66 @@ def create_prediction(self, label, data_row, prediction_model=None):
         res = self.client.execute(query_str, params)
         return Prediction(self.client, res["createPrediction"])
 
+    def upload_annotations(
+            self,
+            name: str,
+            annotations: Union[str, Path, Iterable[dict]],
+    ) -> 'BulkImportRequest':  # type: ignore
+        """ Uploads annotations to a project.
+
+        Args:
+            name: name of the BulkImportRequest job
+            annotations:
+                a URL, publicly accessible by Labelbox, that points to an
+                ndjson file,
+                OR a local path to an ndjson file,
+                OR an iterable of annotation rows
+        Returns:
+            BulkImportRequest
+
+        """
+        if isinstance(annotations, (str, Path)):
+
+            def _is_url_valid(url: Union[str, Path]) -> bool:
+                """ Verifies that the given string is a valid url.
+
+                Args:
+                    url: string to be checked
+                Returns:
+                    True if the given url is valid otherwise False
+
+                """
+                parsed = urlparse(str(url))
+                return bool(parsed.scheme) and bool(parsed.netloc)
+
+            if _is_url_valid(annotations):
+                return BulkImportRequest.create_from_url(
+                    client=self.client,
+                    project_id=self.uid,
+                    name=name,
+                    url=str(annotations),
+                )
+            else:
+                path = Path(annotations)
+                if not path.exists():
+                    raise FileNotFoundError(
+                        f'{annotations} is not a valid URL nor an existing local file'
+                    )
+                return BulkImportRequest.create_from_local_file(
+                    client=self.client,
+                    project_id=self.uid,
+                    name=name,
+                    file=path,
+                    validate_file=True,
+                )
+        else:
+            return BulkImportRequest.create_from_objects(
+                client=self.client,
+                project_id=self.uid,
+                name=name,
+                predictions=annotations,  # type: ignore
+            )
+
 
 class LabelingParameterOverride(DbObject):
     priority = Field.Int("priority")
@@ -361,5 +425,5 @@ class LabelingParameterOverride(DbObject):
 LabelerPerformance = namedtuple(
     "LabelerPerformance", "user count seconds_per_label, total_time_labeling "
     "consensus average_benchmark_agreement last_activity_time")
-LabelerPerformance.__doc__ = "Named tuple containing info about a labeler's " \
-                             "performance."
+LabelerPerformance.__doc__ = (
+    "Named tuple containing info about a labeler's performance.")
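upload_annotations now fronts all three creation paths, dispatching on the argument type; a short sketch (the ids, URL, file path, and payload are placeholders):

    from labelbox import Client

    client = Client(api_key="...")  # key elided
    project = client.get_project("<project-id>")

    # 1) publicly accessible URL
    project.upload_annotations(name="via-url",
                               annotations="https://example.com/preds.ndjson")
    # 2) local ndjson file (validated line by line before upload)
    project.upload_annotations(name="via-file", annotations="preds.ndjson")
    # 3) in-memory iterable of annotation dicts
    project.upload_annotations(name="via-objects", annotations=[{"uuid": "..."}])
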
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 8babdd6b2..71ef87186 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -13,12 +13,35 @@
 IMG_URL = "https://picsum.photos/200/300"
 
 
+class Environ(Enum):
+    PROD = 'prod'
+    STAGING = 'staging'
+
+
+@pytest.fixture
+def environ() -> Environ:
+    """
+    Checks environment variables for LABELBOX_TEST_ENVIRON to be
+    'prod' or 'staging'
+
+    Make sure to set LABELBOX_TEST_ENVIRON in .github/workflows/python-package.yml
+
+    """
+    try:
+        return Environ(os.environ['LABELBOX_TEST_ENVIRON'])
+    # TODO: for some reason all other environs can be set but
+    # this one cannot in github actions
+    #return Environ.PROD
+    except KeyError:
+        raise Exception(f'Missing LABELBOX_TEST_ENVIRON in: {os.environ}')
+
+
 class IntegrationClient(Client):
 
     def __init__(self):
-        api_url = os.environ.get("LABELBOX_TEST_ENDPOINT",
-                                 "https://staging-api.labelbox.com/graphql")
-        super().__init__(os.environ["LABELBOX_TEST_API_KEY"], api_url)
+        api_url = os.environ["LABELBOX_TEST_ENDPOINT"]
+        api_key = os.environ["LABELBOX_TEST_API_KEY"]
+        super().__init__(api_key, api_url)
 
         self.queries = []
@@ -79,29 +102,6 @@ def label_pack(project, rand_gen):
     dataset.delete()
 
 
-class Environ(Enum):
-    PROD = 'prod'
-    STAGING = 'staging'
-
-
-@pytest.fixture
-def environ() -> Environ:
-    """
-    Checks environment variables for LABELBOX_ENVIRON to be
-    'prod' or 'staging'
-
-    Make sure to set LABELBOX_TEST_ENVIRON in .github/workflows/python-package.yaml
-
-    """
-    try:
-        #return Environ(os.environ['LABELBOX_TEST_ENVIRON'])
-        # TODO: for some reason all other environs can be set but
-        # this one cannot in github actions
-        return Environ.PROD
-    except KeyError:
-        raise Exception(f'Missing LABELBOX_TEST_ENVIRON in: {os.environ}')
-
-
 @pytest.fixture
 def iframe_url(environ) -> str:
     return {
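A hedged sketch of how a test consumes the relocated fixture (the test body is illustrative):

    from tests.integration.conftest import Environ

    def test_runs_against_known_environ(environ):
        # pytest injects the fixture; it raises if LABELBOX_TEST_ENVIRON is unset
        assert environ in (Environ.PROD, Environ.STAGING)
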
diff --git a/tests/integration/test_bulk_import_request.py b/tests/integration/test_bulk_import_request.py
index 8a4fde629..20ec7e095 100644
--- a/tests/integration/test_bulk_import_request.py
+++ b/tests/integration/test_bulk_import_request.py
@@ -43,12 +43,11 @@
 }]
 
 
-def test_create_from_url(client, project):
+def test_create_from_url(project):
     name = str(uuid.uuid4())
     url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
 
-    bulk_import_request = BulkImportRequest.create_from_url(
-        client, project.uid, name, url)
+    bulk_import_request = project.upload_annotations(name=name, annotations=url)
 
     assert bulk_import_request.project() == project
     assert bulk_import_request.name == name
@@ -58,11 +57,11 @@
     assert bulk_import_request.state == BulkImportRequestState.RUNNING
 
 
-def test_create_from_objects(client, project):
+def test_create_from_objects(project):
     name = str(uuid.uuid4())
 
-    bulk_import_request = BulkImportRequest.create_from_objects(
-        client, project.uid, name, PREDICTIONS)
+    bulk_import_request = project.upload_annotations(name=name,
+                                                     annotations=PREDICTIONS)
 
     assert bulk_import_request.project() == project
     assert bulk_import_request.name == name
@@ -72,15 +71,15 @@
     __assert_file_content(bulk_import_request.input_file_url)
 
 
-def test_create_from_local_file(tmp_path, client, project):
+def test_create_from_local_file(tmp_path, project):
     name = str(uuid.uuid4())
     file_name = f"{name}.ndjson"
     file_path = tmp_path / file_name
     with file_path.open("w") as f:
         ndjson.dump(PREDICTIONS, f)
 
-    bulk_import_request = BulkImportRequest.create_from_local_file(
-        client, project.uid, name, file_path)
+    bulk_import_request = project.upload_annotations(name=name,
+                                                     annotations=str(file_path))
 
     assert bulk_import_request.project() == project
     assert bulk_import_request.name == name
@@ -93,9 +92,11 @@ def test_get(client, project):
     name = str(uuid.uuid4())
     url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
-    BulkImportRequest.create_from_url(client, project.uid, name, url)
+    project.upload_annotations(name=name, annotations=url)
 
-    bulk_import_request = BulkImportRequest.get(client, project.uid, name)
+    bulk_import_request = BulkImportRequest.from_name(client,
+                                                      project_id=project.uid,
+                                                      name=name)
 
     assert bulk_import_request.project() == project
     assert bulk_import_request.name == name
@@ -105,23 +106,24 @@
     assert bulk_import_request.state == BulkImportRequestState.RUNNING
 
 
-def test_validate_ndjson(tmp_path, client, project):
+def test_validate_ndjson(tmp_path, project):
     file_name = f"broken.ndjson"
     file_path = tmp_path / file_name
     with file_path.open("w") as f:
         f.write("test")
 
     with pytest.raises(ValueError):
-        BulkImportRequest.create_from_local_file(client, project.uid, "name",
-                                                 file_path)
+        project.upload_annotations(name="name", annotations=str(file_path))
 
 
 @pytest.mark.slow
-def test_wait_till_done(client, project):
+def test_wait_till_done(project):
     name = str(uuid.uuid4())
     url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
-    bulk_import_request = BulkImportRequest.create_from_url(
-        client, project.uid, name, url)
+    bulk_import_request = project.upload_annotations(
+        name=name,
+        annotations=url,
+    )
 
     bulk_import_request.wait_until_done()
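End to end, the tests above reduce to this flow (the project id and URL are placeholders):

    import uuid
    from labelbox import Client

    client = Client(api_key="...")  # key elided
    project = client.get_project("<project-id>")

    bulk_import_request = project.upload_annotations(
        name=str(uuid.uuid4()),
        annotations="https://example.com/predictions.ndjson")
    bulk_import_request.wait_until_done()  # blocks until FINISHED or FAILED
    print(bulk_import_request.state)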