From c0f1b34cf2549313e4f3f0414fd5edff64c9da46 Mon Sep 17 00:00:00 2001 From: Kadir Pekel Date: Fri, 27 Dec 2024 17:26:18 +0100 Subject: [PATCH] ENG-1288: Added diarization test into pipelines functional tests --- tests/functional/pipelines/run_test.py | 91 +++++++++++++++++++++----- 1 file changed, 73 insertions(+), 18 deletions(-) diff --git a/tests/functional/pipelines/run_test.py b/tests/functional/pipelines/run_test.py index 985e4a91..7a1138bf 100644 --- a/tests/functional/pipelines/run_test.py +++ b/tests/functional/pipelines/run_test.py @@ -51,7 +51,9 @@ def test_get_pipeline(): def test_run_single_str(batchmode: bool, version: str): pipeline = PipelineFactory.list(query="SingleNodePipeline")["results"][0] - response = pipeline.run(data="Translate this thing", batch_mode=batchmode, **{"version": version}) + response = pipeline.run( + data="Translate this thing", batch_mode=batchmode, **{"version": version} + ) assert response["status"] == "SUCCESS" @@ -91,7 +93,7 @@ def test_run_with_url(batchmode: bool, version: str): response = pipeline.run( data="https://aixplain-platform-assets.s3.amazonaws.com/data/dev/64c81163f8bdcac7443c2dad/data/f8.txt", batch_mode=batchmode, - **{"version": version} + **{"version": version}, ) assert response["status"] == "SUCCESS" @@ -111,7 +113,12 @@ def test_run_with_dataset(batchmode: bool, version: str): data_id = dataset.source_data["en"].id pipeline = PipelineFactory.list(query="SingleNodePipeline")["results"][0] - response = pipeline.run(data=data_id, data_asset=data_asset_id, batch_mode=batchmode, **{"version": version}) + response = pipeline.run( + data=data_id, + data_asset=data_asset_id, + batch_mode=batchmode, + **{"version": version}, + ) assert response["status"] == "SUCCESS" @@ -130,7 +137,7 @@ def test_run_multipipe_with_strings(batchmode: bool, version: str): response = pipeline.run( data={"Input": "Translate this thing.", "Reference": "Traduza esta coisa."}, batch_mode=batchmode, - **{"version": version} + **{"version": version}, ) assert response["status"] == "SUCCESS" @@ -157,15 +164,20 @@ def test_run_multipipe_with_datasets(batchmode: bool, version: str): data={"Input": input_id, "Reference": reference_id}, data_asset={"Input": data_asset_id, "Reference": data_asset_id}, batch_mode=batchmode, - **{"version": version} + **{"version": version}, ) assert response["status"] == "SUCCESS" @pytest.mark.parametrize("version", ["2.0", "3.0"]) def test_run_segment_reconstruct(version: str): - pipeline = PipelineFactory.list(query="Segmentation/Reconstruction Functional Test - DO NOT DELETE")["results"][0] - response = pipeline.run("https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", **{"version": version}) + pipeline = PipelineFactory.list( + query="Segmentation/Reconstruction Functional Test - DO NOT DELETE" + )["results"][0] + response = pipeline.run( + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + **{"version": version}, + ) assert response["status"] == "SUCCESS" output = response["data"][0] @@ -179,11 +191,13 @@ def test_run_translation_metric(version: str): reference_id = dataset.target_data["pt"][0].id - pipeline = PipelineFactory.list(query="Translation Metric Functional Test - DO NOT DELETE")["results"][0] + pipeline = PipelineFactory.list( + query="Translation Metric Functional Test - DO NOT DELETE" + )["results"][0] response = pipeline.run( data={"TextInput": reference_id, "ReferenceInput": reference_id}, data_asset={"TextInput": data_asset_id, "ReferenceInput": data_asset_id}, - **{"version": version} + **{"version": version}, ) assert response["status"] == "SUCCESS" @@ -194,13 +208,15 @@ def test_run_translation_metric(version: str): @pytest.mark.parametrize("version", ["2.0", "3.0"]) def test_run_metric(version: str): - pipeline = PipelineFactory.list(query="ASR Metric Functional Test - DO NOT DELETE")["results"][0] + pipeline = PipelineFactory.list(query="ASR Metric Functional Test - DO NOT DELETE")[ + "results" + ][0] response = pipeline.run( { "AudioInput": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", "ReferenceInput": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.txt", }, - **{"version": version} + **{"version": version}, ) assert response["status"] == "SUCCESS" @@ -212,10 +228,26 @@ def test_run_metric(version: str): @pytest.mark.parametrize( "input_data,output_data,version", [ - ("https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", "AudioOutput", "2.0"), - ("https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.txt", "TextOutput", "2.0"), - ("https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", "AudioOutput", "3.0"), - ("https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.txt", "TextOutput", "3.0"), + ( + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + "AudioOutput", + "2.0", + ), + ( + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.txt", + "TextOutput", + "2.0", + ), + ( + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + "AudioOutput", + "3.0", + ), + ( + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.txt", + "TextOutput", + "3.0", + ), ], ) def test_run_router(input_data: str, output_data: str, version: str): @@ -245,8 +277,13 @@ def test_run_decision(input_data: str, output_data: str, version: str): @pytest.mark.parametrize("version", ["3.0"]) def test_run_script(version: str): - pipeline = PipelineFactory.list(query="Script Functional Test - DO NOT DELETE")["results"][0] - response = pipeline.run("https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", **{"version": version}) + pipeline = PipelineFactory.list(query="Script Functional Test - DO NOT DELETE")[ + "results" + ][0] + response = pipeline.run( + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + **{"version": version}, + ) assert response["status"] == "SUCCESS" data = response["data"][0]["segments"][0]["response"] @@ -255,7 +292,9 @@ def test_run_script(version: str): @pytest.mark.parametrize("version", ["2.0", "3.0"]) def test_run_text_reconstruction(version: str): - pipeline = PipelineFactory.list(query="Text Reconstruction - DO NOT DELETE")["results"][0] + pipeline = PipelineFactory.list(query="Text Reconstruction - DO NOT DELETE")[ + "results" + ][0] response = pipeline.run("Segment A\nSegment B\nSegment C", **{"version": version}) assert response["status"] == "SUCCESS" @@ -268,3 +307,19 @@ def test_run_text_reconstruction(version: str): for d in response["data"]: assert len(d["segments"]) > 0 assert d["segments"][0]["success"] is True + + +@pytest.mark.parametrize("version", ["3.0"]) +def test_run_diarization(version: str): + pipeline = PipelineFactory.list( + query="Diarization ASR Functional Test - DO NOT DELETE" + )["results"][0] + response = pipeline.run( + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + **{"version": version}, + ) + + assert response["status"] == "SUCCESS" + for d in response["data"]: + assert len(d["segments"]) > 0 + assert d["segments"][0]["success"] is True