diff --git a/dagshub/data_engine/annotation/importer.py b/dagshub/data_engine/annotation/importer.py index c19212de..90661df1 100644 --- a/dagshub/data_engine/annotation/importer.py +++ b/dagshub/data_engine/annotation/importer.py @@ -3,6 +3,7 @@ from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Literal, Optional, Union, Sequence, Mapping, Callable, List +from dagshub_annotation_converter.converters.coco import load_coco_from_file from dagshub_annotation_converter.converters.cvat import load_cvat_from_zip from dagshub_annotation_converter.converters.yolo import load_yolo_from_fs from dagshub_annotation_converter.formats.label_studio.task import LabelStudioTask @@ -16,7 +17,7 @@ if TYPE_CHECKING: from dagshub.data_engine.model.datasource import Datasource -AnnotationType = Literal["yolo", "cvat"] +AnnotationType = Literal["yolo", "cvat", "coco"] AnnotationLocation = Literal["repo", "disk"] @@ -85,6 +86,8 @@ def import_annotations(self) -> Mapping[str, Sequence[IRAnnotationBase]]: ) elif self.annotations_type == "cvat": annotation_dict = load_cvat_from_zip(annotations_file) + elif self.annotations_type == "coco": + annotation_dict, _ = load_coco_from_file(annotations_file) return annotation_dict @@ -104,6 +107,9 @@ def download_annotations(self, dest_dir: Path): # Download the annotation data assert context.path is not None repoApi.download(self.annotations_file.parent / context.path, dest_dir, keep_source_prefix=True) + elif self.annotations_type == "coco": + # Download just the annotation file + repoApi.download(self.annotations_file.as_posix(), dest_dir, keep_source_prefix=True) @staticmethod def determine_load_location(ds: "Datasource", annotations_path: Union[str, Path]) -> AnnotationLocation: diff --git a/dagshub/data_engine/model/query_result.py b/dagshub/data_engine/model/query_result.py index b986b5c3..ddec542c 100644 --- a/dagshub/data_engine/model/query_result.py +++ b/dagshub/data_engine/model/query_result.py @@ -15,6 +15,8 @@ import dacite import dagshub_annotation_converter.converters.yolo import rich.progress +from dagshub_annotation_converter.converters.coco import export_to_coco_file +from dagshub_annotation_converter.formats.coco import CocoContext from dagshub_annotation_converter.formats.yolo import YoloContext from dagshub_annotation_converter.formats.yolo.categories import Categories from dagshub_annotation_converter.formats.yolo.common import ir_mapping @@ -778,6 +780,16 @@ def _get_all_annotations(self, annotation_field: str) -> List[IRImageAnnotationB annotations.extend(dp.metadata[annotation_field].annotations) return annotations + def _resolve_annotation_field(self, annotation_field: Optional[str]) -> str: + if annotation_field is not None: + return annotation_field + annotation_fields = sorted([f.name for f in self.fields if f.is_annotation()]) + if len(annotation_fields) == 0: + raise ValueError("No annotation fields found in the datasource") + annotation_field = annotation_fields[0] + log_message(f"Using annotations from field {annotation_field}") + return annotation_field + def export_as_yolo( self, download_dir: Optional[Union[str, Path]] = None, @@ -803,12 +815,7 @@ def export_as_yolo( Returns: The path to the YAML file with the metadata. Pass this path to ``YOLO.train()`` to train a model. """ - if annotation_field is None: - annotation_fields = sorted([f.name for f in self.fields if f.is_annotation()]) - if len(annotation_fields) == 0: - raise ValueError("No annotation fields found in the datasource") - annotation_field = annotation_fields[0] - log_message(f"Using annotations from field {annotation_field}") + annotation_field = self._resolve_annotation_field(annotation_field) if download_dir is None: download_dir = Path("dagshub_export") @@ -861,6 +868,57 @@ def export_as_yolo( log_message(f"Done! Saved YOLO Dataset, YAML file is at {yaml_path.absolute()}") return yaml_path + def export_as_coco( + self, + download_dir: Optional[Union[str, Path]] = None, + annotation_field: Optional[str] = None, + output_filename: str = "annotations.json", + classes: Optional[Dict[int, str]] = None, + ) -> Path: + """ + Downloads the files and exports annotations in COCO format. + + Args: + download_dir: Where to download the files. Defaults to ``./dagshub_export`` + annotation_field: Field with the annotations. If None, uses the first alphabetical annotation field. + output_filename: Name of the output COCO JSON file. Default is ``annotations.json``. + classes: Category mapping for the COCO dataset as ``{id: name}``. + If ``None``, categories will be inferred from the annotations. + + Returns: + Path to the exported COCO JSON file. + """ + annotation_field = self._resolve_annotation_field(annotation_field) + + if download_dir is None: + download_dir = Path("dagshub_export") + download_dir = Path(download_dir) + + annotations = self._get_all_annotations(annotation_field) + if not annotations: + raise RuntimeError("No annotations found to export") + + context = CocoContext() + if classes is not None: + categories = Categories() + for category_id, category_name in classes.items(): + categories.add(category_name, category_id) + context.categories = categories + + # Add the source prefix to all annotations + for ann in annotations: + ann.filename = os.path.join(self.datasource.source.source_prefix, ann.filename) + + image_download_path = download_dir / "data" + log_message("Downloading image files...") + self.download_files(image_download_path) + + output_path = download_dir / output_filename + log_message("Exporting COCO annotations...") + result_path = export_to_coco_file(annotations, output_path, context=context) + log_message(f"Done! Saved COCO annotations to {result_path.absolute()}") + return result_path + def to_voxel51_dataset(self, **kwargs) -> "fo.Dataset": """ Creates a voxel51 dataset that can be used with\ diff --git a/setup.py b/setup.py index 6cdef855..270ae5ef 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ -import setuptools import os.path +import setuptools + # Thank you pip contributors def read(rel_path: str) -> str: @@ -41,7 +42,7 @@ def get_version(rel_path: str) -> str: "python-dateutil", "boto3", "semver", - "dagshub-annotation-converter>=0.1.12", + "dagshub-annotation-converter>=0.2.0", ] extras_require = { diff --git a/tests/data_engine/annotation_import/test_coco.py b/tests/data_engine/annotation_import/test_coco.py new file mode 100644 index 00000000..0db9cf8f --- /dev/null +++ b/tests/data_engine/annotation_import/test_coco.py @@ -0,0 +1,198 @@ +import datetime +import json +from unittest.mock import patch + +import pytest +from dagshub_annotation_converter.ir.image import ( + IRBBoxImageAnnotation, + CoordinateStyle, +) + +from dagshub.data_engine.annotation.importer import AnnotationImporter, AnnotationsNotFoundError +from dagshub.data_engine.annotation.metadata import MetadataAnnotations +from dagshub.data_engine.client.models import MetadataSelectFieldSchema +from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags +from dagshub.data_engine.model.datapoint import Datapoint +from dagshub.data_engine.model.query_result import QueryResult + + +# --- import --- + + +def test_import_coco_from_file(ds, tmp_path): + _write_coco(tmp_path, _make_coco_json()) + importer = AnnotationImporter(ds, "coco", tmp_path / "annotations.json", load_from="disk") + result = importer.import_annotations() + + assert "image1.jpg" in result + assert len(result["image1.jpg"]) == 1 + assert isinstance(result["image1.jpg"][0], IRBBoxImageAnnotation) + + +def test_import_coco_nonexistent_raises(ds, tmp_path): + importer = AnnotationImporter(ds, "coco", tmp_path / "nope.json", load_from="disk") + with pytest.raises(AnnotationsNotFoundError): + importer.import_annotations() + + +def test_coco_convert_to_ls_tasks(ds, tmp_path, mock_dagshub_auth): + importer = AnnotationImporter(ds, "coco", tmp_path / "ann.json", load_from="disk") + bbox = IRBBoxImageAnnotation( + filename="test.jpg", categories={"cat": 1.0}, + top=0.1, left=0.1, width=0.2, height=0.2, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.NORMALIZED, + ) + tasks = importer.convert_to_ls_tasks({"test.jpg": [bbox]}) + + assert "test.jpg" in tasks + task_json = json.loads(tasks["test.jpg"]) + assert "annotations" in task_json + assert len(task_json["annotations"]) > 0 + + +# --- _resolve_annotation_field --- + + +def test_resolve_explicit_field(ds): + qr = _make_qr(ds, [], ann_field="my_ann") + assert qr._resolve_annotation_field("explicit") == "explicit" + + +def test_resolve_auto_field(ds): + qr = _make_qr(ds, [], ann_field="my_ann") + assert qr._resolve_annotation_field(None) == "my_ann" + + +def test_resolve_no_fields_raises(ds): + qr = _make_qr(ds, [], ann_field=None) + with pytest.raises(ValueError, match="No annotation fields"): + qr._resolve_annotation_field(None) + + +def test_resolve_picks_alphabetically_first(ds): + fields = [] + for name in ["zebra_ann", "alpha_ann"]: + fields.append(MetadataSelectFieldSchema( + asOf=int(datetime.datetime.now().timestamp()), + autoGenerated=False, originalName=name, + multiple=False, valueType=MetadataFieldType.BLOB, + name=name, tags={ReservedTags.ANNOTATION.value}, + )) + qr = QueryResult(datasource=ds, _entries=[], fields=fields) + assert qr._resolve_annotation_field(None) == "alpha_ann" + + +# --- export_as_coco --- + + +def test_export_coco_bbox_coordinates(ds, tmp_path): + dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={}) + ann = IRBBoxImageAnnotation( + filename="images/test.jpg", categories={"cat": 1.0}, + top=20.0, left=10.0, width=30.0, height=40.0, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[ann]) + + qr = _make_qr(ds, [dp], ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco(download_dir=tmp_path, annotation_field="ann") + + coco = json.loads(result.read_text()) + assert coco["annotations"][0]["bbox"] == [10.0, 20.0, 30.0, 40.0] + + +def test_export_coco_no_annotations_raises(ds, tmp_path): + dp = Datapoint(datasource=ds, path="test.jpg", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[]) + + qr = _make_qr(ds, [dp], ann_field="ann") + with pytest.raises(RuntimeError, match="No annotations found"): + qr.export_as_coco(download_dir=tmp_path, annotation_field="ann") + + +def test_export_coco_explicit_classes(ds, tmp_path): + dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_image_bbox("images/test.jpg")] + ) + + qr = _make_qr(ds, [dp], ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco( + download_dir=tmp_path, annotation_field="ann", classes={1: "cat", 2: "dog"} + ) + + coco = json.loads(result.read_text()) + assert "cat" in {c["name"] for c in coco["categories"]} + + +def test_export_coco_custom_filename(ds, tmp_path): + dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_image_bbox("images/test.jpg")] + ) + + qr = _make_qr(ds, [dp], ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco( + download_dir=tmp_path, annotation_field="ann", output_filename="custom.json" + ) + + assert result.name == "custom.json" + + +def test_export_coco_multiple_datapoints(ds, tmp_path): + dps = [] + for i, name in enumerate(["a.jpg", "b.jpg"]): + dp = Datapoint(datasource=ds, path=name, datapoint_id=i, metadata={}) + dp.metadata["ann"] = MetadataAnnotations( + datapoint=dp, field="ann", annotations=[_make_image_bbox(name)] + ) + dps.append(dp) + + qr = _make_qr(ds, dps, ann_field="ann") + with patch.object(qr, "download_files"): + result = qr.export_as_coco(download_dir=tmp_path, annotation_field="ann") + + coco = json.loads(result.read_text()) + assert len(coco["annotations"]) == 2 + assert len(coco["images"]) == 2 + + +# --- helpers --- + + +def _make_coco_json(): + return { + "categories": [{"id": 1, "name": "cat"}], + "images": [{"id": 1, "width": 640, "height": 480, "file_name": "image1.jpg"}], + "annotations": [{"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 30, 40]}], + } + + +def _write_coco(tmp_path, coco): + (tmp_path / "annotations.json").write_text(json.dumps(coco)) + + +def _make_image_bbox(filename="test.jpg") -> IRBBoxImageAnnotation: + return IRBBoxImageAnnotation( + filename=filename, categories={"cat": 1.0}, + top=20.0, left=10.0, width=30.0, height=40.0, + image_width=640, image_height=480, + coordinate_style=CoordinateStyle.DENORMALIZED, + ) + + +def _make_qr(ds, datapoints, ann_field=None): + fields = [] + if ann_field: + fields.append(MetadataSelectFieldSchema( + asOf=int(datetime.datetime.now().timestamp()), + autoGenerated=False, originalName=ann_field, + multiple=False, valueType=MetadataFieldType.BLOB, + name=ann_field, tags={ReservedTags.ANNOTATION.value}, + )) + return QueryResult(datasource=ds, _entries=datapoints, fields=fields) diff --git a/tests/data_engine/conftest.py b/tests/data_engine/conftest.py index e8f0c70a..02ee8331 100644 --- a/tests/data_engine/conftest.py +++ b/tests/data_engine/conftest.py @@ -1,11 +1,13 @@ import datetime +from pathlib import PurePosixPath +from unittest.mock import PropertyMock import pytest from dagshub.common.api import UserAPI from dagshub.common.api.responses import UserAPIResponse from dagshub.data_engine import datasources -from dagshub.data_engine.client.models import MetadataSelectFieldSchema, PreprocessingStatus +from dagshub.data_engine.client.models import DatasourceType, MetadataSelectFieldSchema, PreprocessingStatus from dagshub.data_engine.model.datapoint import Datapoint from dagshub.data_engine.model.datasource import DatasetState, Datasource from dagshub.data_engine.model.query_result import QueryResult @@ -26,6 +28,7 @@ def other_ds(mocker, mock_dagshub_auth) -> Datasource: def _create_mock_datasource(mocker, id, name) -> Datasource: ds_state = datasources.DatasourceState(id=id, name=name, repo="kirill/repo") + ds_state.source_type = DatasourceType.REPOSITORY ds_state.path = "repo://kirill/repo/data/" ds_state.preprocessing_status = PreprocessingStatus.READY mocker.patch.object(ds_state, "client") @@ -33,6 +36,7 @@ def _create_mock_datasource(mocker, id, name) -> Datasource: mocker.patch.object(ds_state, "get_from_dagshub") # Stub out root path so all the content_path/etc work without also mocking out RepoAPI mocker.patch.object(ds_state, "_root_path", return_value="http://example.com") + mocker.patch.object(type(ds_state), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath()) ds_state.repoApi = MockRepoAPI("kirill/repo") return Datasource(ds_state) diff --git a/tests/mocks/repo_api.py b/tests/mocks/repo_api.py index d457d161..22b6c94c 100644 --- a/tests/mocks/repo_api.py +++ b/tests/mocks/repo_api.py @@ -113,6 +113,10 @@ def generate_content_api_entry(path, is_dir=False, versioning="dvc") -> ContentA def default_branch(self) -> str: return self._default_branch + @property + def id(self) -> int: + return 1 + def get_connected_storages(self) -> List[StorageAPIEntry]: return self.storages