diff --git a/labelbox/schema/batch.py b/labelbox/schema/batch.py index b4c2373f0..22b6d7a32 100644 --- a/labelbox/schema/batch.py +++ b/labelbox/schema/batch.py @@ -77,7 +77,9 @@ def remove_queued_data_rows(self) -> None: }, experimental=True) - def export_data_rows(self, timeout_seconds=120) -> Generator: + def export_data_rows(self, + timeout_seconds=120, + include_metadata: bool = False) -> Generator: """ Returns a generator that produces all data rows that are currently in this batch. @@ -92,23 +94,24 @@ def export_data_rows(self, timeout_seconds=120) -> Generator: LabelboxError: if the export fails or is unable to download within the specified time. """ id_param = "batchId" - query_str = """mutation GetBatchDataRowsExportUrlPyApi($%s: ID!) - {exportBatchDataRows(data:{batchId: $%s }) {downloadUrl createdAt status}} - """ % (id_param, id_param) + metadata_param = "includeMetadataInput" + query_str = """mutation GetBatchDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!) + {exportBatchDataRows(data:{batchId: $%s , includeMetadataInput: $%s}) {downloadUrl createdAt status}} + """ % (id_param, metadata_param, id_param, metadata_param) sleep_time = 2 while True: - res = self.client.execute(query_str, {id_param: self.uid}) + res = self.client.execute(query_str, { + id_param: self.uid, + metadata_param: include_metadata + }) res = res["exportBatchDataRows"] if res["status"] == "COMPLETE": download_url = res["downloadUrl"] response = requests.get(download_url) response.raise_for_status() reader = ndjson.reader(StringIO(response.text)) - # TODO: Update result to parse metadataFields when resolver returns - return (Entity.DataRow(self.client, { - **result, 'metadataFields': [], - 'customMetadata': [] - }) for result in reader) + return ( + Entity.DataRow(self.client, result) for result in reader) elif res["status"] == "FAILED": raise LabelboxError("Data row export failed.") diff --git a/labelbox/schema/dataset.py b/labelbox/schema/dataset.py index bcdd9c48a..af6400bd6 100644 --- a/labelbox/schema/dataset.py +++ b/labelbox/schema/dataset.py @@ -462,7 +462,9 @@ def data_row_for_external_id(self, external_id) -> "DataRow": external_id) return data_rows[0] - def export_data_rows(self, timeout_seconds=120) -> Generator: + def export_data_rows(self, + timeout_seconds=120, + include_metadata: bool = False) -> Generator: """ Returns a generator that produces all data rows that are currently attached to this dataset. @@ -477,23 +479,24 @@ def export_data_rows(self, timeout_seconds=120) -> Generator: LabelboxError: if the export fails or is unable to download within the specified time. """ id_param = "datasetId" - query_str = """mutation GetDatasetDataRowsExportUrlPyApi($%s: ID!) - {exportDatasetDataRows(data:{datasetId: $%s }) {downloadUrl createdAt status}} - """ % (id_param, id_param) + metadata_param = "includeMetadataInput" + query_str = """mutation GetDatasetDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!) + {exportDatasetDataRows(data:{datasetId: $%s , includeMetadataInput: $%s}) {downloadUrl createdAt status}} + """ % (id_param, metadata_param, id_param, metadata_param) sleep_time = 2 while True: - res = self.client.execute(query_str, {id_param: self.uid}) + res = self.client.execute(query_str, { + id_param: self.uid, + metadata_param: include_metadata + }) res = res["exportDatasetDataRows"] if res["status"] == "COMPLETE": download_url = res["downloadUrl"] response = requests.get(download_url) response.raise_for_status() reader = ndjson.reader(StringIO(response.text)) - # TODO: Update result to parse metadataFields when resolver returns - return (Entity.DataRow(self.client, { - **result, 'metadataFields': [], - 'customMetadata': [] - }) for result in reader) + return ( + Entity.DataRow(self.client, result) for result in reader) elif res["status"] == "FAILED": raise LabelboxError("Data row export failed.") diff --git a/labelbox/schema/project.py b/labelbox/schema/project.py index de8654c4c..e391f1b54 100644 --- a/labelbox/schema/project.py +++ b/labelbox/schema/project.py @@ -185,8 +185,10 @@ def labels(self, datasets=None, order_by=None) -> PaginatedCollection: return PaginatedCollection(self.client, query_str, {id_param: self.uid}, ["project", "labels"], Label) - def export_queued_data_rows(self, - timeout_seconds=120) -> List[Dict[str, str]]: + def export_queued_data_rows( + self, + timeout_seconds=120, + include_metadata: bool = False) -> List[Dict[str, str]]: """ Returns all data rows that are currently enqueued for this project. Args: @@ -197,12 +199,16 @@ def export_queued_data_rows(self, LabelboxError: if the export fails or is unable to download within the specified time. """ id_param = "projectId" - query_str = """mutation GetQueuedDataRowsExportUrlPyApi($%s: ID!) - {exportQueuedDataRows(data:{projectId: $%s }) {downloadUrl createdAt status} } - """ % (id_param, id_param) + metadata_param = "includeMetadataInput" + query_str = """mutation GetQueuedDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!) + {exportQueuedDataRows(data:{projectId: $%s , includeMetadataInput: $%s}) {downloadUrl createdAt status} } + """ % (id_param, metadata_param, id_param, metadata_param) sleep_time = 2 while True: - res = self.client.execute(query_str, {id_param: self.uid}) + res = self.client.execute(query_str, { + id_param: self.uid, + metadata_param: include_metadata + }) res = res["exportQueuedDataRows"] if res["status"] == "COMPLETE": download_url = res["downloadUrl"]