diff --git a/labelbox/schema/task.py b/labelbox/schema/task.py index 88dce6b4a..ea4b83dfe 100644 --- a/labelbox/schema/task.py +++ b/labelbox/schema/task.py @@ -40,6 +40,7 @@ class Task(DbObject): status = Field.String("status") completion_percentage = Field.Float("completion_percentage") result_url = Field.String("result_url", "result") + errors_url = Field.String("errors_url", "errors") type = Field.String("type") _user: Optional["User"] = None @@ -66,7 +67,9 @@ def wait_till_done(self, timeout_seconds: int = 300) -> None: check_frequency = 2 # frequency of checking, in seconds while True: if self.status != "IN_PROGRESS": - if self.errors is not None: + # self.errors fetches the error content. + # This first condition prevents us from downloading the content for v2 exports + if self.errors_url is not None or self.errors is not None: logger.warning( "There are errors present. Please look at `task.errors` for more details" ) @@ -84,21 +87,18 @@ def wait_till_done(self, timeout_seconds: int = 300) -> None: def errors(self) -> Optional[Dict[str, Any]]: """ Fetch the error associated with an import task. """ - if self.type == "add-data-rows-to-batch" or self.type == "send-to-task-queue": + if self.name == 'JSON Import': + if self.status == "FAILED": + result = self._fetch_remote_json() + return result["error"] + elif self.status == "COMPLETE": + return self.failed_data_rows + elif self.type == "export-data-rows": + return self._fetch_remote_json(remote_json_field='errors_url') + elif self.type == "add-data-rows-to-batch" or self.type == "send-to-task-queue": if self.status == "FAILED": # for these tasks, the error is embedded in the result itself return json.loads(self.result_url) - return None - - # TODO: We should handle error messages for export v2 tasks in the future. - if self.name != 'JSON Import': - return None - - if self.status == "FAILED": - result = self._fetch_remote_json() - return result["error"] - elif self.status == "COMPLETE": - return self.failed_data_rows return None @property @@ -130,37 +130,48 @@ def failed_data_rows(self) -> Optional[Dict[str, Any]]: return None @lru_cache() - def _fetch_remote_json(self) -> Dict[str, Any]: + def _fetch_remote_json(self, + remote_json_field: Optional[str] = None + ) -> Dict[str, Any]: """ Function for fetching and caching the result data. """ - def download_result(): - response = requests.get(self.result_url) + def download_result(remote_json_field: Optional[str], format: str): + url = getattr(self, remote_json_field or 'result_url') + + if url is None: + return None + + response = requests.get(url) response.raise_for_status() - try: + if format == 'json': return response.json() - except Exception as e: - pass - try: + elif format == 'ndjson': return ndjson.loads(response.text) - except Exception as e: - raise ValueError("Failed to parse task JSON/NDJSON result.") + else: + raise ValueError( + "Expected the result format to be either `ndjson` or `json`." + ) - if self.name != 'JSON Import' and self.type != 'export-data-rows': + if self.name == 'JSON Import': + format = 'json' + elif self.type == 'export-data-rows': + format = 'ndjson' + else: raise ValueError( "Task result is only supported for `JSON Import` and `export` tasks." " Download task.result_url manually to access the result for other tasks." ) if self.status != "IN_PROGRESS": - return download_result() + return download_result(remote_json_field, format) else: self.wait_till_done(timeout_seconds=600) if self.status == "IN_PROGRESS": raise ValueError( "Job status still in `IN_PROGRESS`. The result is not available. Call task.wait_till_done() with a larger timeout or contact support." ) - return download_result() + return download_result(remote_json_field, format) @staticmethod def get_task(client, task_id): diff --git a/tests/integration/annotation_import/test_model_run.py b/tests/integration/annotation_import/test_model_run.py index d7cb48d90..14f5a065e 100644 --- a/tests/integration/annotation_import/test_model_run.py +++ b/tests/integration/annotation_import/test_model_run.py @@ -173,14 +173,9 @@ def test_model_run_export_v2(model_run_with_model_run_data_rows, assert task.name == task_name task.wait_till_done() assert task.status == "COMPLETE" + assert task.errors is None - def download_result(result_url): - response = requests.get(result_url) - response.raise_for_status() - data = [json.loads(line) for line in response.text.splitlines()] - return data - - task_results = download_result(task.result_url) + task_results = task.result label_ids = [label.uid for label in configured_project.labels()] label_ids_set = set(label_ids) diff --git a/tests/integration/test_project.py b/tests/integration/test_project.py index ce168fb73..17b661cea 100644 --- a/tests/integration/test_project.py +++ b/tests/integration/test_project.py @@ -60,16 +60,9 @@ def test_project_export_v2(configured_project_with_label): assert task.name == task_name task.wait_till_done() assert task.status == "COMPLETE" + assert task.errors is None - def download_result(result_url): - response = requests.get(result_url) - response.raise_for_status() - data = [json.loads(line) for line in response.text.splitlines()] - return data - - task_results = download_result(task.result_url) - - for task_result in task_results: + for task_result in task.result: task_project = task_result['projects'][project.uid] task_project_label_ids_set = set( map(lambda prediction: prediction['id'], task_project['labels']))