diff --git a/CHANGELOG.md b/CHANGELOG.md index de9ee6ff4..f0a955532 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ## Added * All imports are available via `import labelbox as lb` and `import labelbox.types as lb_types`. * Attachment_name support to create_attachment() +* New method `Project.task_queues()` to obtain the task queues for a project. +* New method `Project.move_data_rows_to_task_queue()` for moving data rows to a specified task queue. ## Changed * `LabelImport.create_from_objects()`, `MALPredictionImport.create_from_objects()`, `MEAPredictionImport.create_from_objects()`, `Project.upload_annotations()`, `ModelRun.add_predictions()` now support Python Types for annotations. diff --git a/docs/source/index.rst b/docs/source/index.rst index 694f2b8b9..a28843d22 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -118,6 +118,12 @@ Task :members: :show-inheritance: +Task Queue +--------------------------- +.. automodule:: labelbox.schema.task_queue + :members: + :show-inheritance: + User --------------------------- diff --git a/labelbox/__init__.py b/labelbox/__init__.py index 47938fe2f..d9acd8374 100644 --- a/labelbox/__init__.py +++ b/labelbox/__init__.py @@ -29,3 +29,4 @@ from labelbox.schema.media_type import MediaType from labelbox.schema.slice import Slice, CatalogSlice from labelbox.schema.queue_mode import QueueMode +from labelbox.schema.task_queue import TaskQueue diff --git a/labelbox/orm/model.py b/labelbox/orm/model.py index 02ba5288b..f4f09d8c8 100644 --- a/labelbox/orm/model.py +++ b/labelbox/orm/model.py @@ -378,6 +378,7 @@ class Entity(metaclass=EntityMeta): Project: Type[labelbox.Project] Batch: Type[labelbox.Batch] CatalogSlice: Type[labelbox.CatalogSlice] + TaskQueue: Type[labelbox.TaskQueue] @classmethod def _attributes_of_type(cls, attr_type): diff --git a/labelbox/schema/project.py b/labelbox/schema/project.py index 05f8f4646..f002982d1 100644 --- a/labelbox/schema/project.py +++ b/labelbox/schema/project.py @@ -26,6 +26,7 @@ from labelbox.schema.resource_tag import ResourceTag from labelbox.schema.task import Task from labelbox.schema.user import User +from labelbox.schema.task_queue import TaskQueue if TYPE_CHECKING: from labelbox import BulkImportRequest @@ -69,6 +70,7 @@ class Project(DbObject, Updateable, Deletable): webhooks (Relationship): `ToMany` relationship to Webhook benchmarks (Relationship): `ToMany` relationship to Benchmark ontology (Relationship): `ToOne` relationship to Ontology + task_queues (Relationship): `ToMany` relationship to TaskQueue """ name = Field.String("name") @@ -794,54 +796,34 @@ def _create_batch_async(self, task_id = res['taskId'] - timeout_seconds = 600 - sleep_time = 2 - get_task_query_str = """query %s($taskId: ID!) { - task(where: {id: $taskId}) { - status + task = self._wait_for_task(task_id) + if task.status != "COMPLETE": + raise LabelboxError(f"Batch was not created successfully: " + + json.dumps(task.errors)) + + # obtain batch entity to return + get_batch_str = """query %s($projectId: ID!, $batchId: ID!) { + project(where: {id: $projectId}) { + batches(where: {id: $batchId}) { + nodes { + %s + } + } } } - """ % "getTaskPyApi" + """ % ("getProjectBatchPyApi", + query.results_query_part(Entity.Batch)) - while True: - task_status = self.client.execute( - get_task_query_str, {'taskId': task_id}, - experimental=True)['task']['status'] - - if task_status == "COMPLETE": - # obtain batch entity to return - get_batch_str = """query %s($projectId: ID!, $batchId: ID!) { - project(where: {id: $projectId}) { - batches(where: {id: $batchId}) { - nodes { - %s - } - } - } - } - """ % ("getProjectBatchPyApi", - query.results_query_part(Entity.Batch)) - - batch = self.client.execute( - get_batch_str, { - "projectId": self.uid, - "batchId": batch_id - }, - timeout=180.0, - experimental=True)["project"]["batches"]["nodes"][0] - - # TODO async endpoints currently do not provide failed_data_row_ids in response - return Entity.Batch(self.client, self.uid, batch) - elif task_status == "IN_PROGRESS": - timeout_seconds -= sleep_time - if timeout_seconds <= 0: - raise LabelboxError( - f"Timed out while waiting for batch to be created.") - logger.debug("Creating batch, waiting for server...", self.uid) - time.sleep(sleep_time) - continue - else: - raise LabelboxError(f"Batch was not created successfully.") + batch = self.client.execute( + get_batch_str, { + "projectId": self.uid, + "batchId": batch_id + }, + timeout=180.0, + experimental=True)["project"]["batches"]["nodes"][0] + + # TODO async endpoints currently do not provide failed_data_row_ids in response + return Entity.Batch(self.client, self.uid, batch) def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode": """ @@ -1127,6 +1109,81 @@ def batches(self) -> PaginatedCollection: cursor_path=['project', 'batches', 'pageInfo', 'endCursor'], experimental=True) + def task_queues(self) -> List[TaskQueue]: + """ Fetch all task queues that belong to this project + + Returns: + A `List of `TaskQueue`s + """ + query_str = """query GetProjectTaskQueuesPyApi($projectId: ID!) { + project(where: {id: $projectId}) { + taskQueues { + %s + } + } + } + """ % (query.results_query_part(Entity.TaskQueue)) + + task_queue_values = self.client.execute( + query_str, {"projectId": self.uid}, + timeout=180.0, + experimental=True)["project"]["taskQueues"] + + return [ + Entity.TaskQueue(self.client, field_values) + for field_values in task_queue_values + ] + + def move_data_rows_to_task_queue(self, data_row_ids: List[str], + task_queue_id: str): + """ + + Moves data rows to the specified task queue. + + Args: + data_row_ids: a list of data row ids to be moved + task_queue_id: the task queue id to be moved to, or None to specify the "Done" queue + + Returns: + None if successful, or a raised error on failure + + """ + method = "createBulkAddRowsToQueueTask" + query_str = """mutation AddDataRowsToTaskQueueAsyncPyApi( + $projectId: ID! + $queueId: ID + $dataRowIds: [ID!]! + ) { + project(where: { id: $projectId }) { + %s( + data: { queueId: $queueId, dataRowIds: $dataRowIds } + ) { + taskId + } + } + } + """ % method + + task_id = self.client.execute( + query_str, { + "projectId": self.uid, + "queueId": task_queue_id, + "dataRowIds": data_row_ids + }, + timeout=180.0, + experimental=True)["project"][method]["taskId"] + + task = self._wait_for_task(task_id) + if task.status != "COMPLETE": + raise LabelboxError(f"Data rows were not moved successfully: " + + json.dumps(task.errors)) + + def _wait_for_task(self, task_id: str) -> Task: + task = Task.get_task(self.client, task_id) + task.wait_till_done() + + return task + def upload_annotations( self, name: str, diff --git a/labelbox/schema/task.py b/labelbox/schema/task.py index 31f3390ee..88dce6b4a 100644 --- a/labelbox/schema/task.py +++ b/labelbox/schema/task.py @@ -1,3 +1,4 @@ +import json import logging import requests import time @@ -6,7 +7,7 @@ from labelbox.exceptions import ResourceNotFoundError from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship +from labelbox.orm.model import Field, Relationship, Entity if TYPE_CHECKING: from labelbox import User @@ -55,7 +56,7 @@ def refresh(self) -> None: for field in self.fields(): setattr(self, field.name, getattr(tasks[0], field.name)) - def wait_till_done(self, timeout_seconds=300) -> None: + def wait_till_done(self, timeout_seconds: int = 300) -> None: """ Waits until the task is completed. Periodically queries the server to update the task attributes. @@ -83,9 +84,16 @@ def wait_till_done(self, timeout_seconds=300) -> None: def errors(self) -> Optional[Dict[str, Any]]: """ Fetch the error associated with an import task. """ + if self.type == "add-data-rows-to-batch" or self.type == "send-to-task-queue": + if self.status == "FAILED": + # for these tasks, the error is embedded in the result itself + return json.loads(self.result_url) + return None + # TODO: We should handle error messages for export v2 tasks in the future. if self.name != 'JSON Import': return None + if self.status == "FAILED": result = self._fetch_remote_json() return result["error"] @@ -153,3 +161,17 @@ def download_result(): "Job status still in `IN_PROGRESS`. The result is not available. Call task.wait_till_done() with a larger timeout or contact support." ) return download_result() + + @staticmethod + def get_task(client, task_id): + user: User = client.get_user() + tasks: List[Task] = list( + user.created_tasks(where=Entity.Task.uid == task_id)) + # Cache user in a private variable as the relationship can't be + # resolved due to server-side limitations (see Task.created_by) + # for more info. + if len(tasks) != 1: + raise ResourceNotFoundError(Entity.Task, {task_id: task_id}) + task: Task = tasks[0] + task._user = user + return task diff --git a/labelbox/schema/task_queue.py b/labelbox/schema/task_queue.py new file mode 100644 index 000000000..6e6a84d00 --- /dev/null +++ b/labelbox/schema/task_queue.py @@ -0,0 +1,28 @@ +from labelbox.orm.db_object import DbObject +from labelbox.orm.model import Field + + +class TaskQueue(DbObject): + """ + a task queue + + Attributes + name + description + queue_type + data_row_count + + Relationships + project + organization + pass_queue + fail_queue + """ + + name = Field.String("name") + description = Field.String("description") + queue_type = Field.String("queue_type") + data_row_count = Field.Int("data_row_count") + + def __init__(self, client, *args, **kwargs): + super().__init__(client, *args, **kwargs) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a86d25bc6..fe0e387d0 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -311,18 +311,39 @@ def configured_project_with_label(client, rand_gen, image_url, project, dataset, One label is already created and yielded when using fixture """ project.datasets.connect(dataset) - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - ontology_builder = OntologyBuilder(tools=[ - Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), - ]) - project.setup(editor, ontology_builder.asdict()) - # TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent - time.sleep(2) + ontology = _setup_ontology(project) + label = _create_label(project, datarow, ontology, wait_for_label_processing) + + yield [project, dataset, datarow, label] + + for label in project.labels(): + label.delete() + + +@pytest.fixture +def configured_batch_project_with_label(client, rand_gen, image_url, + batch_project, dataset, datarow, + wait_for_label_processing): + """Project with a batch having one datarow + Project contains an ontology with 1 bbox tool + Additionally includes a create_label method for any needed extra labels + One label is already created and yielded when using fixture + """ + data_rows = [dr.uid for dr in list(dataset.data_rows())] + batch_project.create_batch("test-batch", data_rows) + + ontology = _setup_ontology(batch_project) + label = _create_label(batch_project, datarow, ontology, + wait_for_label_processing) + + yield [batch_project, dataset, datarow, label] + + for label in batch_project.labels(): + label.delete() + - ontology = ontology_builder.from_project(project) +def _create_label(project, datarow, ontology, wait_for_label_processing): predictions = [{ "uuid": str(uuid.uuid4()), "schemaId": ontology.tools[0].feature_schema_id, @@ -342,7 +363,8 @@ def create_label(): Creates a LabelImport task which will create a label """ upload_task = LabelImport.create_from_objects( - client, project.uid, f'label-import-{uuid.uuid4()}', predictions) + project.client, project.uid, f'label-import-{uuid.uuid4()}', + predictions) upload_task.wait_until_done(sleep_time_seconds=5) assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" assert len( @@ -352,11 +374,20 @@ def create_label(): project.create_label = create_label project.create_label() label = wait_for_label_processing(project)[0] + return label - yield [project, dataset, datarow, label] - for label in project.labels(): - label.delete() +def _setup_ontology(project): + editor = list( + project.client.get_labeling_frontends( + where=LabelingFrontend.name == "editor"))[0] + ontology_builder = OntologyBuilder(tools=[ + Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), + ]) + project.setup(editor, ontology_builder.asdict()) + # TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent + time.sleep(2) + return ontology_builder.from_project(project) @pytest.fixture diff --git a/tests/integration/test_task_queue.py b/tests/integration/test_task_queue.py new file mode 100644 index 000000000..d8560355c --- /dev/null +++ b/tests/integration/test_task_queue.py @@ -0,0 +1,37 @@ +import time + +from labelbox import Project + + +def test_get_task_queue(batch_project: Project): + task_queues = batch_project.task_queues() + assert len(task_queues) == 3 + review_queue = next( + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") + assert review_queue + + +def test_move_to_task(configured_batch_project_with_label: Project): + project, _, data_row, label = configured_batch_project_with_label + task_queues = project.task_queues() + + review_queue = next( + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") + project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) + + timeout_seconds = 30 + sleep_time = 2 + while True: + task_queues = project.task_queues() + review_queue = next( + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") + + if review_queue.data_row_count == 1: + break + + if timeout_seconds <= 0: + raise AssertionError( + "Timed out expecting data_row_count of 1 in the review queue") + + timeout_seconds -= sleep_time + time.sleep(sleep_time)