59 changes: 46 additions & 13 deletions tests/functional/benchmark/benchmark_functional_test.py
@@ -2,6 +2,7 @@
 import pandas as pd
 import json
 from dotenv import load_dotenv
+import time

 load_dotenv()
 from aixplain.factories import ModelFactory, DatasetFactory, MetricFactory, BenchmarkFactory
@@ -34,23 +35,55 @@ def run_input_map(request):
 def module_input_map(request):
     return request.param

+def is_job_finshed(benchmark_job):
[Review comment from Contributor]: Typo in method name.
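Presumably is_job_finshed is meant to be is_job_finished; a sketch of the rename, which would also need to be applied at the call site in test_create_and_run below:

```python
def is_job_finished(benchmark_job):  # spelling fixed; body unchanged
    ...

# call site in test_create_and_run:
assert is_job_finished(benchmark_job), "Job did not finish in time"
```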

+    time_taken = 0
+    sleep_time = 15
+    timeout = 10 * 60
+    while True:
+        if time_taken > timeout:
+            break
+        job_status = benchmark_job.check_status()
+        if job_status == "in_progress":
+            time.sleep(sleep_time)
+            time_taken += sleep_time
+        elif job_status == "completed":
+            return True
+        else:
+            break
+    return False

-def test_run(run_input_map):
+def assert_correct_results(benchmark_job):
+    df = benchmark_job.download_results_as_csv(return_dataframe=True)
+    assert type(df) is pd.DataFrame, "Couldn't download CSV"
+    model_success_rate = (sum(df["Model_success"])*100)/len(df.index)
+    assert model_success_rate > 80, f"Low model success rate ({model_success_rate})"
+    metric_name = "BLEU by sacrebleu"
+    mean_score = df[metric_name].mean()
+    assert mean_score != 0, f"Zero Mean Score - Please check metric ({metric_name})"
+
+
+def test_create_and_run(run_input_map):
     model_list = [ModelFactory.get(model_id) for model_id in run_input_map["model_ids"]]
-    dataset_list = [DatasetFactory.get(dataset_id) for dataset_id in run_input_map["dataset_ids"]]
+    dataset_list = [DatasetFactory.list(query=dataset_name)["results"][0] for dataset_name in run_input_map["dataset_names"]]
     metric_list = [MetricFactory.get(metric_id) for metric_id in run_input_map["metric_ids"]]
     benchmark = BenchmarkFactory.create(f"SDK Benchmark Test {uuid.uuid4()}", dataset_list, model_list, metric_list)
-    assert type(benchmark) is Benchmark
+    assert type(benchmark) is Benchmark, "Couldn't create benchmark"
     benchmark_job = benchmark.start()
-    assert type(benchmark_job) is BenchmarkJob
+    assert type(benchmark_job) is BenchmarkJob, "Couldn't start job"
+    assert is_job_finshed(benchmark_job), "Job did not finish in time"
+    assert_correct_results(benchmark_job)


-def test_module(module_input_map):
-    benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"])
-    assert benchmark.id == module_input_map["benchmark_id"]
-    benchmark_job = benchmark.job_list[0]
-    assert benchmark_job.benchmark_id == module_input_map["benchmark_id"]
-    job_status = benchmark_job.check_status()
-    assert job_status in ["in_progress", "completed"]
-    df = benchmark_job.download_results_as_csv(return_dataframe=True)
-    assert type(df) is pd.DataFrame
+# def test_module(module_input_map):
[Review comment from Contributor]: You can add the annotation @pytest.mark.skip() instead of commenting the test out, if you do not want to run it.
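A minimal sketch of that suggestion, assuming the original test body is restored rather than commented out; the reason string is illustrative:

```python
import pytest

@pytest.mark.skip(reason="temporarily disabled")
def test_module(module_input_map):
    benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"])
    assert benchmark.id == module_input_map["benchmark_id"]
```

With the marker in place, pytest still collects the test but reports it as skipped instead of running it.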

+# benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"])
+# assert benchmark.id == module_input_map["benchmark_id"]
+# benchmark_job = benchmark.job_list[0]
+# assert benchmark_job.benchmark_id == module_input_map["benchmark_id"]
+# job_status = benchmark_job.check_status()
+# assert job_status in ["in_progress", "completed"]
+# df = benchmark_job.download_results_as_csv(return_dataframe=True)
+# assert type(df) is pd.DataFrame
3 changes: 2 additions & 1 deletion tests/functional/benchmark/data/benchmark_test_run_data.json
@@ -2,6 +2,7 @@
 {
   "model_ids": ["61b097551efecf30109d32da", "60ddefbe8d38c51c5885f98a"],
   "dataset_ids": ["64da34a813d879bec2323aa3"],
[Review comment from Contributor]: You can remove dataset_ids from here since it is not used anymore.
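With that suggestion applied, the entry would presumably shrink to:

```json
{
  "model_ids": ["61b097551efecf30109d32da", "60ddefbe8d38c51c5885f98a"],
  "dataset_names": ["EnHi SDK Test - Benchmark Dataset"],
  "metric_ids": ["639874ab506c987b1ae1acc6", "6408942f166427039206d71e"]
}
```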

"dataset_names": ["EnHi SDK Test - Benchmark Dataset"],
"metric_ids": ["639874ab506c987b1ae1acc6", "6408942f166427039206d71e"]
}
 ]