Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion dagshub/data_engine/annotation/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Literal, Optional, Union, Sequence, Mapping, Callable, List

from dagshub_annotation_converter.converters.coco import load_coco_from_file
from dagshub_annotation_converter.converters.cvat import load_cvat_from_zip
from dagshub_annotation_converter.converters.yolo import load_yolo_from_fs
from dagshub_annotation_converter.formats.label_studio.task import LabelStudioTask
Expand All @@ -16,7 +17,7 @@
if TYPE_CHECKING:
from dagshub.data_engine.model.datasource import Datasource

AnnotationType = Literal["yolo", "cvat"]
AnnotationType = Literal["yolo", "cvat", "coco"]
AnnotationLocation = Literal["repo", "disk"]


Expand Down Expand Up @@ -85,6 +86,8 @@ def import_annotations(self) -> Mapping[str, Sequence[IRAnnotationBase]]:
)
elif self.annotations_type == "cvat":
annotation_dict = load_cvat_from_zip(annotations_file)
elif self.annotations_type == "coco":
annotation_dict, _ = load_coco_from_file(annotations_file)

return annotation_dict

Expand All @@ -104,6 +107,9 @@ def download_annotations(self, dest_dir: Path):
# Download the annotation data
assert context.path is not None
repoApi.download(self.annotations_file.parent / context.path, dest_dir, keep_source_prefix=True)
elif self.annotations_type == "coco":
# Download just the annotation file
repoApi.download(self.annotations_file.as_posix(), dest_dir, keep_source_prefix=True)

@staticmethod
def determine_load_location(ds: "Datasource", annotations_path: Union[str, Path]) -> AnnotationLocation:
Expand Down
70 changes: 64 additions & 6 deletions dagshub/data_engine/model/query_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import dacite
import dagshub_annotation_converter.converters.yolo
import rich.progress
from dagshub_annotation_converter.converters.coco import export_to_coco_file
from dagshub_annotation_converter.formats.coco import CocoContext
from dagshub_annotation_converter.formats.yolo import YoloContext
from dagshub_annotation_converter.formats.yolo.categories import Categories
from dagshub_annotation_converter.formats.yolo.common import ir_mapping
Expand Down Expand Up @@ -778,6 +780,16 @@ def _get_all_annotations(self, annotation_field: str) -> List[IRImageAnnotationB
annotations.extend(dp.metadata[annotation_field].annotations)
return annotations

def _resolve_annotation_field(self, annotation_field: Optional[str]) -> str:
if annotation_field is not None:
return annotation_field
annotation_fields = sorted([f.name for f in self.fields if f.is_annotation()])
if len(annotation_fields) == 0:
raise ValueError("No annotation fields found in the datasource")
annotation_field = annotation_fields[0]
log_message(f"Using annotations from field {annotation_field}")
return annotation_field

def export_as_yolo(
self,
download_dir: Optional[Union[str, Path]] = None,
Expand All @@ -803,12 +815,7 @@ def export_as_yolo(
Returns:
The path to the YAML file with the metadata. Pass this path to ``YOLO.train()`` to train a model.
"""
if annotation_field is None:
annotation_fields = sorted([f.name for f in self.fields if f.is_annotation()])
if len(annotation_fields) == 0:
raise ValueError("No annotation fields found in the datasource")
annotation_field = annotation_fields[0]
log_message(f"Using annotations from field {annotation_field}")
annotation_field = self._resolve_annotation_field(annotation_field)

if download_dir is None:
download_dir = Path("dagshub_export")
Expand Down Expand Up @@ -861,6 +868,57 @@ def export_as_yolo(
log_message(f"Done! Saved YOLO Dataset, YAML file is at {yaml_path.absolute()}")
return yaml_path

def export_as_coco(
    self,
    download_dir: Optional[Union[str, Path]] = None,
    annotation_field: Optional[str] = None,
    output_filename: str = "annotations.json",
    classes: Optional[Dict[int, str]] = None,
) -> Path:
    """
    Downloads the files and exports annotations in COCO format.

    Args:
        download_dir: Where to download the files. Defaults to ``./dagshub_export``
        annotation_field: Field with the annotations. If None, uses the first alphabetical annotation field.
        output_filename: Name of the output COCO JSON file. Default is ``annotations.json``.
        classes: Category mapping for the COCO dataset as ``{id: name}``.
            If ``None``, categories will be inferred from the annotations.

    Returns:
        Path to the exported COCO JSON file.
    """
    annotation_field = self._resolve_annotation_field(annotation_field)

    export_root = Path(download_dir) if download_dir is not None else Path("dagshub_export")

    annotations = self._get_all_annotations(annotation_field)
    if not annotations:
        raise RuntimeError("No annotations found to export")

    context = CocoContext()
    if classes is not None:
        # Build an explicit category table instead of inferring it.
        category_table = Categories()
        for category_id, category_name in classes.items():
            category_table.add(category_name, category_id)
        context.categories = category_table

    # Prefix every annotation filename with the datasource's source prefix
    # so the paths in the COCO file line up with the downloaded data layout.
    # NOTE(review): this mutates the annotation objects held in datapoint
    # metadata in place.
    source_prefix = self.datasource.source.source_prefix
    for annotation in annotations:
        annotation.filename = os.path.join(source_prefix, annotation.filename)

    log_message("Downloading image files...")
    self.download_files(export_root / "data")

    log_message("Exporting COCO annotations...")
    result_path = export_to_coco_file(annotations, export_root / output_filename, context=context)
    log_message(f"Done! Saved COCO annotations to {result_path.absolute()}")
    return result_path

def to_voxel51_dataset(self, **kwargs) -> "fo.Dataset":
"""
Creates a voxel51 dataset that can be used with\
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import setuptools
import os.path

import setuptools


# Thank you pip contributors
def read(rel_path: str) -> str:
Expand Down Expand Up @@ -41,7 +42,7 @@ def get_version(rel_path: str) -> str:
"python-dateutil",
"boto3",
"semver",
"dagshub-annotation-converter>=0.1.12",
"dagshub-annotation-converter>=0.2.0",
]

extras_require = {
Expand Down
198 changes: 198 additions & 0 deletions tests/data_engine/annotation_import/test_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import datetime
import json
from unittest.mock import patch

import pytest
from dagshub_annotation_converter.ir.image import (
IRBBoxImageAnnotation,
CoordinateStyle,
)

from dagshub.data_engine.annotation.importer import AnnotationImporter, AnnotationsNotFoundError
from dagshub.data_engine.annotation.metadata import MetadataAnnotations
from dagshub.data_engine.client.models import MetadataSelectFieldSchema
from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags
from dagshub.data_engine.model.datapoint import Datapoint
from dagshub.data_engine.model.query_result import QueryResult


# --- import ---


def test_import_coco_from_file(ds, tmp_path):
    """A COCO JSON on disk imports into per-file IR bbox annotations."""
    _write_coco(tmp_path, _make_coco_json())

    imported = AnnotationImporter(
        ds, "coco", tmp_path / "annotations.json", load_from="disk"
    ).import_annotations()

    assert "image1.jpg" in imported
    image_annotations = imported["image1.jpg"]
    assert len(image_annotations) == 1
    assert isinstance(image_annotations[0], IRBBoxImageAnnotation)


def test_import_coco_nonexistent_raises(ds, tmp_path):
    """Pointing the importer at a missing file raises AnnotationsNotFoundError."""
    missing_file = tmp_path / "nope.json"
    with pytest.raises(AnnotationsNotFoundError):
        AnnotationImporter(ds, "coco", missing_file, load_from="disk").import_annotations()


def test_coco_convert_to_ls_tasks(ds, tmp_path, mock_dagshub_auth):
    """Imported IR annotations convert into Label Studio task JSON per file."""
    importer = AnnotationImporter(ds, "coco", tmp_path / "ann.json", load_from="disk")
    annotation = IRBBoxImageAnnotation(
        filename="test.jpg",
        categories={"cat": 1.0},
        top=0.1,
        left=0.1,
        width=0.2,
        height=0.2,
        image_width=640,
        image_height=480,
        coordinate_style=CoordinateStyle.NORMALIZED,
    )

    tasks = importer.convert_to_ls_tasks({"test.jpg": [annotation]})

    assert "test.jpg" in tasks
    parsed_task = json.loads(tasks["test.jpg"])
    assert "annotations" in parsed_task
    assert len(parsed_task["annotations"]) > 0


# --- _resolve_annotation_field ---


def test_resolve_explicit_field(ds):
    """An explicitly passed field name is returned unchanged."""
    qr = _make_qr(ds, [], ann_field="my_ann")
    resolved = qr._resolve_annotation_field("explicit")
    assert resolved == "explicit"


def test_resolve_auto_field(ds):
    """When no field is given, the single annotation field is auto-selected."""
    qr = _make_qr(ds, [], ann_field="my_ann")
    resolved = qr._resolve_annotation_field(None)
    assert resolved == "my_ann"


def test_resolve_no_fields_raises(ds):
    """Resolution fails loudly when the datasource has no annotation fields."""
    qr = _make_qr(ds, [], ann_field=None)

    with pytest.raises(ValueError, match="No annotation fields"):
        qr._resolve_annotation_field(None)


def test_resolve_picks_alphabetically_first(ds):
    """With multiple annotation fields, the alphabetically first one wins."""
    timestamp = int(datetime.datetime.now().timestamp())
    fields = [
        MetadataSelectFieldSchema(
            asOf=timestamp,
            autoGenerated=False,
            originalName=field_name,
            multiple=False,
            valueType=MetadataFieldType.BLOB,
            name=field_name,
            tags={ReservedTags.ANNOTATION.value},
        )
        for field_name in ("zebra_ann", "alpha_ann")
    ]
    qr = QueryResult(datasource=ds, _entries=[], fields=fields)

    assert qr._resolve_annotation_field(None) == "alpha_ann"


# --- export_as_coco ---


def test_export_coco_bbox_coordinates(ds, tmp_path):
    """Exported bboxes come out in COCO's [x, y, width, height] pixel layout."""
    dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={})
    bbox = IRBBoxImageAnnotation(
        filename="images/test.jpg",
        categories={"cat": 1.0},
        top=20.0,
        left=10.0,
        width=30.0,
        height=40.0,
        image_width=640,
        image_height=480,
        coordinate_style=CoordinateStyle.DENORMALIZED,
    )
    dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[bbox])
    qr = _make_qr(ds, [dp], ann_field="ann")

    # Downloading images is irrelevant here — only the JSON output matters.
    with patch.object(qr, "download_files"):
        exported = qr.export_as_coco(download_dir=tmp_path, annotation_field="ann")

    payload = json.loads(exported.read_text())
    assert payload["annotations"][0]["bbox"] == [10.0, 20.0, 30.0, 40.0]


def test_export_coco_no_annotations_raises(ds, tmp_path):
    """Exporting when the annotation field holds no annotations raises."""
    dp = Datapoint(datasource=ds, path="test.jpg", datapoint_id=0, metadata={})
    dp.metadata["ann"] = MetadataAnnotations(datapoint=dp, field="ann", annotations=[])
    qr = _make_qr(ds, [dp], ann_field="ann")

    with pytest.raises(RuntimeError, match="No annotations found"):
        qr.export_as_coco(download_dir=tmp_path, annotation_field="ann")


def test_export_coco_explicit_classes(ds, tmp_path):
    """A user-supplied {id: name} mapping shows up in the COCO categories."""
    dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={})
    dp.metadata["ann"] = MetadataAnnotations(
        datapoint=dp, field="ann", annotations=[_make_image_bbox("images/test.jpg")]
    )
    qr = _make_qr(ds, [dp], ann_field="ann")

    with patch.object(qr, "download_files"):
        exported = qr.export_as_coco(
            download_dir=tmp_path,
            annotation_field="ann",
            classes={1: "cat", 2: "dog"},
        )

    payload = json.loads(exported.read_text())
    category_names = {category["name"] for category in payload["categories"]}
    assert "cat" in category_names


def test_export_coco_custom_filename(ds, tmp_path):
    """The exported JSON is written under the requested file name."""
    dp = Datapoint(datasource=ds, path="images/test.jpg", datapoint_id=0, metadata={})
    dp.metadata["ann"] = MetadataAnnotations(
        datapoint=dp, field="ann", annotations=[_make_image_bbox("images/test.jpg")]
    )
    qr = _make_qr(ds, [dp], ann_field="ann")

    with patch.object(qr, "download_files"):
        exported = qr.export_as_coco(
            download_dir=tmp_path,
            annotation_field="ann",
            output_filename="custom.json",
        )

    assert exported.name == "custom.json"


def test_export_coco_multiple_datapoints(ds, tmp_path):
    """Each datapoint contributes one image entry and one annotation."""
    datapoints = []
    for index, filename in enumerate(("a.jpg", "b.jpg")):
        dp = Datapoint(datasource=ds, path=filename, datapoint_id=index, metadata={})
        dp.metadata["ann"] = MetadataAnnotations(
            datapoint=dp, field="ann", annotations=[_make_image_bbox(filename)]
        )
        datapoints.append(dp)
    qr = _make_qr(ds, datapoints, ann_field="ann")

    with patch.object(qr, "download_files"):
        exported = qr.export_as_coco(download_dir=tmp_path, annotation_field="ann")

    payload = json.loads(exported.read_text())
    assert len(payload["annotations"]) == 2
    assert len(payload["images"]) == 2


# --- helpers ---


def _make_coco_json():
    """Return a minimal valid COCO payload: one category, one image, one bbox."""
    return {
        "categories": [{"id": 1, "name": "cat"}],
        "images": [{"id": 1, "width": 640, "height": 480, "file_name": "image1.jpg"}],
        "annotations": [{"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 20, 30, 40]}],
    }


def _write_coco(tmp_path, coco):
    """Serialize ``coco`` as JSON into ``annotations.json`` under ``tmp_path``."""
    (tmp_path / "annotations.json").write_text(json.dumps(coco))


def _make_image_bbox(filename="test.jpg") -> IRBBoxImageAnnotation:
    """Build a pixel-space (denormalized) bbox annotation for ``filename``.

    The box is at (x=10, y=20) with size 30x40 in a 640x480 image, matching
    the expected COCO bbox ``[10.0, 20.0, 30.0, 40.0]`` asserted in tests.
    """
    return IRBBoxImageAnnotation(
        filename=filename, categories={"cat": 1.0},
        top=20.0, left=10.0, width=30.0, height=40.0,
        image_width=640, image_height=480,
        coordinate_style=CoordinateStyle.DENORMALIZED,
    )


def _make_qr(ds, datapoints, ann_field=None):
    """Build a QueryResult over ``datapoints``, optionally with one annotation field.

    When ``ann_field`` is falsy the result carries no fields at all, which lets
    tests exercise the "no annotation fields" error path.
    """
    fields = []
    if ann_field:
        # A BLOB field tagged as ANNOTATION is what marks it as an annotation field.
        fields.append(MetadataSelectFieldSchema(
            asOf=int(datetime.datetime.now().timestamp()),
            autoGenerated=False, originalName=ann_field,
            multiple=False, valueType=MetadataFieldType.BLOB,
            name=ann_field, tags={ReservedTags.ANNOTATION.value},
        ))
    return QueryResult(datasource=ds, _entries=datapoints, fields=fields)
6 changes: 5 additions & 1 deletion tests/data_engine/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import datetime
from pathlib import PurePosixPath
from unittest.mock import PropertyMock

import pytest

from dagshub.common.api import UserAPI
from dagshub.common.api.responses import UserAPIResponse
from dagshub.data_engine import datasources
from dagshub.data_engine.client.models import MetadataSelectFieldSchema, PreprocessingStatus
from dagshub.data_engine.client.models import DatasourceType, MetadataSelectFieldSchema, PreprocessingStatus
from dagshub.data_engine.model.datapoint import Datapoint
from dagshub.data_engine.model.datasource import DatasetState, Datasource
from dagshub.data_engine.model.query_result import QueryResult
Expand All @@ -26,13 +28,15 @@ def other_ds(mocker, mock_dagshub_auth) -> Datasource:

def _create_mock_datasource(mocker, id, name) -> Datasource:
    """Build a Datasource backed by a fully stubbed DatasourceState for tests.

    NOTE(review): the ``id`` parameter shadows the ``id`` builtin; renaming it
    would break keyword callers, so it is left as-is and only flagged here.
    """
    ds_state = datasources.DatasourceState(id=id, name=name, repo="kirill/repo")
    ds_state.source_type = DatasourceType.REPOSITORY
    ds_state.path = "repo://kirill/repo/data/"
    ds_state.preprocessing_status = PreprocessingStatus.READY
    # No real client in tests — network calls must never happen.
    mocker.patch.object(ds_state, "client")
    # Stub out get_from_dagshub, because it doesn't need to be done in tests
    mocker.patch.object(ds_state, "get_from_dagshub")
    # Stub out root path so all the content_path/etc work without also mocking out RepoAPI
    mocker.patch.object(ds_state, "_root_path", return_value="http://example.com")
    # source_prefix is a property, so it is patched on the type with an empty
    # posix path — keeps filename-prefixing code a no-op in tests.
    mocker.patch.object(type(ds_state), "source_prefix", new_callable=PropertyMock, return_value=PurePosixPath())
    ds_state.repoApi = MockRepoAPI("kirill/repo")
    return Datasource(ds_state)

Expand Down
4 changes: 4 additions & 0 deletions tests/mocks/repo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ def generate_content_api_entry(path, is_dir=False, versioning="dvc") -> ContentA
def default_branch(self) -> str:
return self._default_branch

@property
def id(self) -> int:
    """Fixed fake repository ID used by the mock (always 1)."""
    return 1

def get_connected_storages(self) -> List[StorageAPIEntry]:
return self.storages

Expand Down
Loading