
Commit

update benchmark tests (#98)
* Update benchmark_functional_test.py

* Update benchmark_test_run_data.json
shreyasXplain authored Dec 19, 2023
1 parent f668332 commit b53ca78
Showing 2 changed files with 48 additions and 14 deletions.
59 changes: 46 additions & 13 deletions tests/functional/benchmark/benchmark_functional_test.py
@@ -2,6 +2,7 @@
import pandas as pd
import json
from dotenv import load_dotenv
+import time

load_dotenv()
from aixplain.factories import ModelFactory, DatasetFactory, MetricFactory, BenchmarkFactory
@@ -34,23 +35,55 @@ def run_input_map(request):
def module_input_map(request):
    return request.param

+def is_job_finished(benchmark_job):
+    time_taken = 0
+    sleep_time = 15
+    timeout = 10 * 60
+    while True:
+        if time_taken > timeout:
+            break
+        job_status = benchmark_job.check_status()
+        if job_status == "in_progress":
+            time.sleep(sleep_time)
+            time_taken += sleep_time
+        elif job_status == "completed":
+            return True
+        else:
+            break
+    return False

-def test_run(run_input_map):
+def assert_correct_results(benchmark_job):
+    df = benchmark_job.download_results_as_csv(return_dataframe=True)
+    assert type(df) is pd.DataFrame, "Couldn't download CSV"
+    model_success_rate = (sum(df["Model_success"]) * 100) / len(df.index)
+    assert model_success_rate > 80, f"Low model success rate ({model_success_rate})"
+    metric_name = "BLEU by sacrebleu"
+    mean_score = df[metric_name].mean()
+    assert mean_score != 0, f"Zero Mean Score - Please check metric ({metric_name})"
+
+
+def test_create_and_run(run_input_map):
    model_list = [ModelFactory.get(model_id) for model_id in run_input_map["model_ids"]]
-    dataset_list = [DatasetFactory.get(dataset_id) for dataset_id in run_input_map["dataset_ids"]]
+    dataset_list = [DatasetFactory.list(query=dataset_name)["results"][0] for dataset_name in run_input_map["dataset_names"]]
    metric_list = [MetricFactory.get(metric_id) for metric_id in run_input_map["metric_ids"]]
    benchmark = BenchmarkFactory.create(f"SDK Benchmark Test {uuid.uuid4()}", dataset_list, model_list, metric_list)
-    assert type(benchmark) is Benchmark
+    assert type(benchmark) is Benchmark, "Couldn't create benchmark"
    benchmark_job = benchmark.start()
-    assert type(benchmark_job) is BenchmarkJob
+    assert type(benchmark_job) is BenchmarkJob, "Couldn't start job"
+    assert is_job_finished(benchmark_job), "Job did not finish in time"
+    assert_correct_results(benchmark_job)


-def test_module(module_input_map):
-    benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"])
-    assert benchmark.id == module_input_map["benchmark_id"]
-    benchmark_job = benchmark.job_list[0]
-    assert benchmark_job.benchmark_id == module_input_map["benchmark_id"]
-    job_status = benchmark_job.check_status()
-    assert job_status in ["in_progress", "completed"]
-    df = benchmark_job.download_results_as_csv(return_dataframe=True)
-    assert type(df) is pd.DataFrame
+# def test_module(module_input_map):
+#     benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"])
+#     assert benchmark.id == module_input_map["benchmark_id"]
+#     benchmark_job = benchmark.job_list[0]
+#     assert benchmark_job.benchmark_id == module_input_map["benchmark_id"]
+#     job_status = benchmark_job.check_status()
+#     assert job_status in ["in_progress", "completed"]
+#     df = benchmark_job.download_results_as_csv(return_dataframe=True)
+#     assert type(df) is pd.DataFrame
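
For context on the new assert_correct_results helper: the success-rate check is plain arithmetic over the downloaded results frame. Below is a minimal standalone sketch with illustrative data; the column names follow the test, but the values are made up and not from a real benchmark run.

import pandas as pd

# Illustrative stand-in for the frame returned by
# benchmark_job.download_results_as_csv(return_dataframe=True).
df = pd.DataFrame({
    "Model_success": [1, 1, 1, 0, 1],               # 1 = model produced output for the row
    "BLEU by sacrebleu": [32.1, 28.4, 0.0, 0.0, 41.7],
})

# Same computation as the test: percentage of rows the model handled successfully.
model_success_rate = (sum(df["Model_success"]) * 100) / len(df.index)
print(model_success_rate)  # 80.0 -> this toy data would fail the "> 80" assertion

# A mean metric score of exactly 0 suggests the metric never ran.
print(df["BLEU by sacrebleu"].mean())  # 20.44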
3 changes: 2 additions & 1 deletion tests/functional/benchmark/data/benchmark_test_run_data.json
@@ -2,6 +2,7 @@
    {
        "model_ids": ["61b097551efecf30109d32da", "60ddefbe8d38c51c5885f98a"],
        "dataset_ids": ["64da34a813d879bec2323aa3"],
+       "dataset_names": ["EnHi SDK Test - Benchmark Dataset"],
        "metric_ids": ["639874ab506c987b1ae1acc6", "6408942f166427039206d71e"]
    }
-]
+]
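
The diff only shows the signatures of the run_input_map and module_input_map fixtures, so here is a hedged sketch of how entries from benchmark_test_run_data.json could be fed into them via pytest parametrization. The path handling, the read_data helper, and the fixture wiring are assumptions for illustration, not part of this commit.

import json
from pathlib import Path

import pytest

# Assumed location, relative to benchmark_functional_test.py.
RUN_DATA_PATH = Path(__file__).parent / "data" / "benchmark_test_run_data.json"

def read_data(path):
    # Hypothetical helper: load the list of run configurations from JSON.
    with open(path) as f:
        return json.load(f)

@pytest.fixture(params=read_data(RUN_DATA_PATH))
def run_input_map(request):
    # Each param is one dict from the JSON list, carrying the "model_ids",
    # "dataset_names" and "metric_ids" consumed by test_create_and_run.
    return request.param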
