diff --git a/labelbox/schema/project.py b/labelbox/schema/project.py
index 9618d72f5..14bb96230 100644
--- a/labelbox/schema/project.py
+++ b/labelbox/schema/project.py
@@ -1,27 +1,31 @@
+import enum
 import json
-import time
 import logging
+import time
+import warnings
 from collections import namedtuple
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Dict, Union, Iterable
+from typing import Dict, Union, Iterable, List, Optional
 from urllib.parse import urlparse
-import requests
+
 import ndjson
+import requests
 
 from labelbox import utils
-from labelbox.schema.data_row import DataRow
-from labelbox.orm import query
-from labelbox.schema.bulk_import_request import BulkImportRequest
 from labelbox.exceptions import InvalidQueryError, LabelboxError
+from labelbox.orm import query
 from labelbox.orm.db_object import DbObject, Updateable, Deletable
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.pagination import PaginatedCollection
+from labelbox.schema.bulk_import_request import BulkImportRequest
+from labelbox.schema.data_row import DataRow
 
 try:
     datetime.fromisoformat  # type: ignore[attr-defined]
 except AttributeError:
     from backports.datetime_fromisoformat import MonkeyPatch
+
     MonkeyPatch.patch_fromisoformat()
 
 try:
@@ -31,6 +35,19 @@
 
 logger = logging.getLogger(__name__)
 
+MAX_QUEUE_BATCH_SIZE = 1000
+
+
+class QueueMode(enum.Enum):
+    Batch = "Batch"
+    Dataset = "Dataset"
+
+
+class QueueErrors(enum.Enum):
+    InvalidDataRowType = 'InvalidDataRowType'
+    AlreadyInProject = 'AlreadyInProject'
+    HasAttachedLabel = 'HasAttachedLabel'
+
 
 class Project(DbObject, Updateable, Deletable):
     """ A Project is a container that includes a labeling frontend, an ontology,
@@ -79,6 +96,14 @@ class Project(DbObject, Updateable, Deletable):
     benchmarks = Relationship.ToMany("Benchmark", False)
     ontology = Relationship.ToOne("Ontology", True)
 
+    def update(self, **kwargs):
+
+        mode: Optional[QueueMode] = kwargs.pop("queue_mode", None)
+        if mode:
+            self._update_queue_mode(mode)
+
+        return super().update(**kwargs)
+
     def members(self):
         """ Fetch all current members for this project
 
@@ -407,14 +432,14 @@ def setup(self, labeling_frontend, labeling_frontend_options):
                 a.k.a. project ontology. If given a `dict` it will be converted
                 to `str` using `json.dumps`.
""" - organization = self.client.get_organization() + if not isinstance(labeling_frontend_options, str): labeling_frontend_options = json.dumps(labeling_frontend_options) self.labeling_frontend.connect(labeling_frontend) LFO = Entity.LabelingFrontendOptions - labeling_frontend_options = self.client._create( + self.client._create( LFO, { LFO.project: self, LFO.labeling_frontend: labeling_frontend, @@ -424,6 +449,103 @@ def setup(self, labeling_frontend, labeling_frontend_options): timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") self.update(setup_complete=timestamp) + def queue(self, data_row_ids: List[str]): + """Add Data Rows to the Project queue""" + + method = "submitBatchOfDataRows" + return self._post_batch(method, data_row_ids) + + def dequeue(self, data_row_ids: List[str]): + """Remove Data Rows from the Project queue""" + + method = "removeBatchOfDataRows" + return self._post_batch(method, data_row_ids) + + def _post_batch(self, method, data_row_ids: List[str]): + """Post batch methods""" + + if self.queue_mode() != QueueMode.Batch: + raise ValueError("Project must be in batch mode") + + if len(data_row_ids) > MAX_QUEUE_BATCH_SIZE: + raise ValueError( + f"Batch exceeds max size of {MAX_QUEUE_BATCH_SIZE}, consider breaking it into parts" + ) + + query = """mutation %sPyApi($projectId: ID!, $dataRowIds: [ID!]!) { + project(where: {id: $projectId}) { + %s(data: {dataRowIds: $dataRowIds}) { + dataRows { + dataRowId + error + } + } + } + } + """ % (method, method) + + res = self.client.execute(query, { + "projectId": self.uid, + "dataRowIds": data_row_ids + })["project"][method]["dataRows"] + + # TODO: figure out error messaging + if len(data_row_ids) == len(res): + raise ValueError("No dataRows were submitted successfully") + + if len(data_row_ids) > 0: + warnings.warn("Some Data Rows were not submitted successfully") + + return res + + def _update_queue_mode(self, mode: QueueMode) -> QueueMode: + + if self.queue_mode() == mode: + return mode + + if mode == QueueMode.Batch: + status = "ENABLED" + elif mode == QueueMode.Dataset: + status = "DISABLED" + else: + raise ValueError( + "Must provide either `BATCH` or `DATASET` as a mode") + + query_str = """mutation %s($projectId: ID!, $status: TagSetStatusInput!) { + project(where: {id: $projectId}) { + setTagSetStatus(input: {tagSetStatus: $status}) { + tagSetStatus + } + } + } + """ % "setTagSetStatusPyApi" + + self.client.execute(query_str, { + 'projectId': self.uid, + 'status': status + }) + + return mode + + def queue_mode(self): + + query_str = """query %s($projectId: ID!) 
+            project(where: {id: $projectId}) {
+                tagSetStatus
+            }
+        }
+        """ % "GetTagSetStatusPyApi"
+
+        status = self.client.execute(
+            query_str, {'projectId': self.uid})["project"]["tagSetStatus"]
+
+        if status == "ENABLED":
+            return QueueMode.Batch
+        elif status == "DISABLED":
+            return QueueMode.Dataset
+        else:
+            raise ValueError(f"Unknown queue mode status: {status}")
+
     def validate_labeling_parameter_overrides(self, data):
         for idx, row in enumerate(data):
             if len(row) != 3:
diff --git a/tests/data/annotation_types/data/test_text.py b/tests/data/annotation_types/data/test_text.py
index 4bca4f939..35dc20a28 100644
--- a/tests/data/annotation_types/data/test_text.py
+++ b/tests/data/annotation_types/data/test_text.py
@@ -1,3 +1,5 @@
+import os
+
 import pytest
 from pydantic import ValidationError
 
@@ -22,11 +24,13 @@ def test_url():
     assert len(text) == 3541
 
 
-def test_file():
-    file_path = "tests/data/assets/sample_text.txt"
-    text_data = TextData(file_path=file_path)
-    text = text_data.value
-    assert len(text) == 3541
+def test_file(tmpdir):
+    content = "foo bar baz"
+    file = "hello.txt"
+    dir = tmpdir.mkdir('data')
+    dir.join(file).write(content)
+    text_data = TextData(file_path=os.path.join(dir.strpath, file))
+    assert len(text_data.value) == len(content)
 
 
 def test_ref():
diff --git a/tests/data/annotation_types/geometry/test_rectangle.py b/tests/data/annotation_types/geometry/test_rectangle.py
index 14d2b7316..d8586aeb7 100644
--- a/tests/data/annotation_types/geometry/test_rectangle.py
+++ b/tests/data/annotation_types/geometry/test_rectangle.py
@@ -1,6 +1,6 @@
-from pydantic import ValidationError
-import pytest
 import cv2
+import pytest
+from pydantic import ValidationError
 
 from labelbox.data.annotation_types import Point, Rectangle
 
@@ -18,3 +18,7 @@ def test_rectangle():
     raster = rectangle.draw(height=32, width=32)
 
     assert (cv2.imread("tests/data/assets/rectangle.png") == raster).all()
+
+    xyhw = Rectangle.from_xyhw(0., 0, 10, 10)
+    assert xyhw.start == Point(x=0, y=0.)
+    assert xyhw.end == Point(x=10, y=10.0)
diff --git a/tests/integration/test_batch.py b/tests/integration/test_batch.py
new file mode 100644
index 000000000..8534d8d64
--- /dev/null
+++ b/tests/integration/test_batch.py
@@ -0,0 +1,30 @@
+import pytest
+
+from labelbox import Dataset, Project
+from labelbox.schema.project import QueueMode
+
+IMAGE_URL = "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000000034.jpg"
+
+
+@pytest.fixture
+def big_dataset(dataset: Dataset):
+    task = dataset.create_data_rows([
+        {
+            "row_data": IMAGE_URL,
+            "external_id": "my-image"
+        },
+    ] * 250)
+    task.wait_till_done()
+
+    yield dataset
+    dataset.delete()
+
+
+def test_submit_batch(configured_project: Project, big_dataset):
+    configured_project.update(queue_mode=QueueMode.Batch)
+
+    data_rows = [dr.uid for dr in list(big_dataset.export_data_rows())]
+    queue_res = configured_project.queue(data_rows)
+    assert not len(queue_res)
+    dequeue_res = configured_project.dequeue(data_rows)
+    assert not len(dequeue_res)
diff --git a/tests/integration/test_project.py b/tests/integration/test_project.py
index ef4307dfa..ec7a24ac7 100644
--- a/tests/integration/test_project.py
+++ b/tests/integration/test_project.py
@@ -1,11 +1,10 @@
 import json
 
-import requests
-import ndjson
 import pytest
 
 from labelbox import Project, LabelingFrontend
 from labelbox.exceptions import InvalidQueryError
+from labelbox.schema.project import QueueMode
 
 
 def test_project(client, rand_gen):
@@ -107,3 +106,9 @@ def test_attach_instructions(client, project):
 def test_queued_data_row_export(configured_project):
     result = configured_project.export_queued_data_rows()
     assert len(result) == 1
+
+
+def test_queue_mode(configured_project: Project):
+    assert configured_project.queue_mode() == QueueMode.Dataset
+    configured_project.update(queue_mode=QueueMode.Batch)
+    assert configured_project.queue_mode() == QueueMode.Batch
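
Usage note: taken together, the project.py changes above add an opt-in batch queue mode alongside the existing dataset-based queue. A minimal usage sketch follows, assuming an existing project and dataset; the API key and ids below are placeholders, and `Client`, `get_project`, `get_dataset`, and `export_data_rows` are the existing labelbox-python APIs already exercised by the tests above:

    from labelbox import Client
    from labelbox.schema.project import QueueMode

    client = Client(api_key="<API_KEY>")          # placeholder credentials
    project = client.get_project("<PROJECT_ID>")  # placeholder id
    dataset = client.get_dataset("<DATASET_ID>")  # placeholder id

    # Switch the project from the default dataset queue to batch mode.
    project.update(queue_mode=QueueMode.Batch)
    assert project.queue_mode() == QueueMode.Batch

    # Queue (and later dequeue) the dataset's data rows; per the tests above,
    # the returned list holds only the rows that failed, so empty means success.
    data_row_ids = [dr.uid for dr in dataset.export_data_rows()]
    failed = project.queue(data_row_ids)
    assert not failed
    project.dequeue(data_row_ids)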
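
Because _post_batch raises when a single call exceeds MAX_QUEUE_BATCH_SIZE (1000 data rows), callers with larger submissions have to split them up themselves. A hypothetical helper, not part of this diff, sketching one way to do that:

    from labelbox.schema.project import MAX_QUEUE_BATCH_SIZE

    def queue_in_chunks(project, data_row_ids, chunk_size=MAX_QUEUE_BATCH_SIZE):
        # Hypothetical helper: submit data rows in chunks the endpoint accepts
        # and collect every row that failed across all chunks.
        failed = []
        for start in range(0, len(data_row_ids), chunk_size):
            failed.extend(project.queue(data_row_ids[start:start + chunk_size]))
        return failed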