diff --git a/tests/data/serialization/ndjson/test_video.py b/tests/data/serialization/ndjson/test_video.py index a21c2a539..5946d0bab 100644 --- a/tests/data/serialization/ndjson/test_video.py +++ b/tests/data/serialization/ndjson/test_video.py @@ -12,7 +12,6 @@ from labelbox import parser from labelbox.data.serialization.ndjson.converter import NDJsonConverter -from labelbox.schema.annotation_import import MALPredictionImport def test_video(): diff --git a/tests/integration/annotation_import/conftest.py b/tests/integration/annotation_import/conftest.py index d50c44d0c..e0b3e4a18 100644 --- a/tests/integration/annotation_import/conftest.py +++ b/tests/integration/annotation_import/conftest.py @@ -508,8 +508,12 @@ def configured_project(client, initial_dataset, ontology, rand_gen, image_url): data_row_ids = [] - for _ in range(len(ontology['tools']) + len(ontology['classifications'])): - data_row_ids.append(dataset.create_data_row(row_data=image_url).uid) + ontologies = ontology['tools'] + ontology['classifications'] + for ind in range(len(ontologies)): + data_row_ids.append( + dataset.create_data_row( + row_data=image_url, + global_key=f"gk_{ontologies[ind]['name']}_{rand_gen(str)}").uid) project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids, sleep_interval=3) diff --git a/tests/integration/annotation_import/test_data_types.py b/tests/integration/annotation_import/test_data_types.py index 79e8b03cb..f4ac3c82f 100644 --- a/tests/integration/annotation_import/test_data_types.py +++ b/tests/integration/annotation_import/test_data_types.py @@ -180,6 +180,51 @@ def test_import_data_types( data_row.delete() +def test_import_data_types_by_global_key( + client, + configured_project, + initial_dataset, + rand_gen, + data_row_json_by_data_type, + annotations_by_data_type, +): + + project = configured_project + project_id = project.uid + dataset = initial_dataset + data_type_class = ImageData + set_project_media_type_from_data_type(project, data_type_class) + + data_row_ndjson = data_row_json_by_data_type['image'] + data_row_ndjson['global_key'] = str(uuid.uuid4()) + data_row = create_data_row_for_project(project, dataset, data_row_ndjson, + rand_gen(str)) + + annotations_ndjson = annotations_by_data_type['image'] + annotations_list = [ + label.annotations + for label in NDJsonConverter.deserialize(annotations_ndjson) + ] + labels = [ + lb_types.Label(data=data_type_class(global_key=data_row.global_key), + annotations=annotations) + for annotations in annotations_list + ] + + label_import = lb.LabelImport.create_from_objects(client, project_id, + f'test-import-image', + labels) + label_import.wait_until_done() + + assert label_import.state == AnnotationImportState.FINISHED + assert len(label_import.errors) == 0 + exported_labels = project.export_labels(download=True) + objects = exported_labels[0]['Label']['objects'] + classifications = exported_labels[0]['Label']['classifications'] + assert len(objects) + len(classifications) == len(labels) + data_row.delete() + + def validate_iso_format(date_string: str): parsed_t = datetime.datetime.fromisoformat( date_string) #this will blow up if the string is not in iso format @@ -321,6 +366,17 @@ def one_datarow(client, rand_gen, data_row_json_by_data_type, data_type): dataset.delete() +@pytest.fixture +def one_datarow_global_key(client, rand_gen, data_row_json_by_data_type): + dataset = client.create_dataset(name=rand_gen(str)) + data_row_json = data_row_json_by_data_type['video'] + data_row = dataset.create_data_row(data_row_json) + + yield data_row + + dataset.delete() + + @pytest.mark.parametrize('data_type, data_class, annotations', test_params) def test_import_mal_annotations(client, configured_project_with_one_data_row, data_type, data_class, annotations, rand_gen, @@ -348,3 +404,33 @@ def test_import_mal_annotations(client, configured_project_with_one_data_row, assert import_annotations.errors == [] # MAL Labels cannot be exported and compared to input labels + + +def test_import_mal_annotations_global_key(client, + configured_project_with_one_data_row, + rand_gen, one_datarow_global_key): + data_class = lb_types.VideoData + data_row = one_datarow_global_key + annotations = [video_mask_annotation] + set_project_media_type_from_data_type(configured_project_with_one_data_row, + data_class) + + configured_project_with_one_data_row.create_batch( + rand_gen(str), + [data_row.uid], + ) + + labels = [ + lb_types.Label(data=data_class(global_key=data_row.global_key), + annotations=annotations) + ] + + import_annotations = lb.MALPredictionImport.create_from_objects( + client=client, + project_id=configured_project_with_one_data_row.uid, + name=f"import {str(uuid.uuid4())}", + predictions=labels) + import_annotations.wait_until_done() + + assert import_annotations.errors == [] + # MAL Labels cannot be exported and compared to input labels diff --git a/tests/integration/annotation_import/test_mea_prediction_import.py b/tests/integration/annotation_import/test_mea_prediction_import.py index c457130be..d39e1ad8e 100644 --- a/tests/integration/annotation_import/test_mea_prediction_import.py +++ b/tests/integration/annotation_import/test_mea_prediction_import.py @@ -37,6 +37,25 @@ def test_create_from_objects(model_run_with_data_rows, object_predictions, annotation_import.wait_until_done() +def test_create_from_objects_global_key(client, model_run_with_data_rows, + entity_inference, + annotation_import_test_helpers): + name = str(uuid.uuid4()) + dr = client.get_data_row(entity_inference['dataRow']['id']) + del entity_inference['dataRow']['id'] + entity_inference['dataRow']['globalKey'] = dr.global_key + object_predictions = [entity_inference] + + annotation_import = model_run_with_data_rows.add_predictions( + name=name, predictions=object_predictions) + + assert annotation_import.model_run_id == model_run_with_data_rows.uid + annotation_import_test_helpers.check_running_state(annotation_import, name) + annotation_import_test_helpers.assert_file_content( + annotation_import.input_file_url, object_predictions) + annotation_import.wait_until_done() + + def test_create_from_objects_with_confidence(predictions_with_confidence, model_run_with_data_rows, annotation_import_test_helpers): diff --git a/tests/integration/annotation_import/test_upsert_prediction_import.py b/tests/integration/annotation_import/test_upsert_prediction_import.py index 927b6526d..55f227315 100644 --- a/tests/integration/annotation_import/test_upsert_prediction_import.py +++ b/tests/integration/annotation_import/test_upsert_prediction_import.py @@ -1,8 +1,6 @@ import uuid from labelbox import parser import pytest - -from labelbox.schema.annotation_import import AnnotationImportState, MEAPredictionImport """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised @@ -28,7 +26,7 @@ def test_create_from_url(client, tmp_path, object_predictions, if p['dataRow']['id'] in model_run_data_rows ] with file_path.open("w") as f: - ndjson.dump(predictions, f) + parser.dump(predictions, f) # Needs to have data row ids @@ -114,7 +112,7 @@ def test_create_from_local_file(tmp_path, model_run_with_data_rows, ] with file_path.open("w") as f: - ndjson.dump(predictions, f) + parser.dump(predictions, f) annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( name=name,