Merge pull request #605 from JohnSnowLabs/refactor/change-runtime-speed-into-a-test

Refactor/change runtime speed into a test
ArshaanNazir committed Jul 31, 2023
2 parents a75e887 + 0f43910 commit bd45e29
Showing 12 changed files with 6,967 additions and 1,703 deletions.
4,126 changes: 4,126 additions & 0 deletions demo/tutorials/misc/PerformanceTest_Notebook.ipynb

Large diffs are not rendered by default.

1,551 changes: 0 additions & 1,551 deletions demo/tutorials/misc/RuntimeTest_Notebook.ipynb

This file was deleted.

2,421 changes: 2,421 additions & 0 deletions demo/tutorials/misc/Templatic_Augmentation_Notebook.ipynb

Large diffs are not rendered by default.
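RuntimeTest_Notebook is removed and PerformanceTest_Notebook takes its place, reflecting the point of this PR: measuring runtime speed becomes an ordinary test with a pass/fail criterion rather than a separate runtime report. As a rough, hypothetical sketch of that idea (the helper below is illustrative only and is not taken from the notebook or the library):

# Hypothetical illustration: model speed as a test with its own pass criterion,
# instead of timings returned alongside other results.
import time
from typing import Callable, List

def speed_test(predict: Callable[[str], object], samples: List[str],
               min_samples_per_second: float) -> bool:
    """Return True when the model is fast enough to pass."""
    start = time.perf_counter()
    for text in samples:
        predict(text)
    elapsed = time.perf_counter() - start
    throughput = len(samples) / elapsed if elapsed > 0 else float("inf")
    return throughput >= min_samples_per_second

# A trivial "model" that uppercases text easily clears a 100-samples-per-second bar.
print(speed_test(str.upper, ["some text"] * 500, min_samples_per_second=100))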

4 changes: 2 additions & 2 deletions langtest/augmentation/__init__.py
@@ -144,7 +144,7 @@ def fix(self, input_path: str, output_path, export_mode: str = "add"):
test_type["robustness"]["swap_entities"]["parameters"][
"labels"
] = [self.label[each]]
res, _ = TestFactory.transform(
res = TestFactory.transform(
self.task, [hash_map[each]], test_type
)
hash_map[each] = res[0]
@@ -160,7 +160,7 @@ def fix(self, input_path: str, output_path, export_mode: str = "add"):
]
else:
sample_data = random.choices(data, k=int(sample_length))
aug_data, _ = TestFactory.transform(self.task, sample_data, test_type)
aug_data = TestFactory.transform(self.task, sample_data, test_type)
final_aug_data.extend(aug_data)

if export_mode == "inplace":
2 changes: 1 addition & 1 deletion langtest/data/config/translation_transformers_config.yml
@@ -4,7 +4,7 @@ model_parameters:

tests:
defaults:
min_pass_rate: 1.0,
min_pass_rate: 1.0
robustness:
add_typo:
min_pass_rate: 0.70
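The one-character change above fixes a real parsing problem: in block-style YAML a trailing comma is not a separator, so min_pass_rate: 1.0, loads as the string "1.0," rather than the float 1.0. A quick check, not part of the PR, that shows the difference:

import yaml  # PyYAML, already imported elsewhere in this repository

broken = yaml.safe_load("min_pass_rate: 1.0,")
fixed = yaml.safe_load("min_pass_rate: 1.0")

print(type(broken["min_pass_rate"]).__name__, repr(broken["min_pass_rate"]))  # str '1.0,'
print(type(fixed["min_pass_rate"]).__name__, repr(fixed["min_pass_rate"]))    # float 1.0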
79 changes: 23 additions & 56 deletions langtest/langtest.py
@@ -9,7 +9,6 @@
import yaml
from pkg_resources import resource_filename

from langtest.utils.custom_types.sample import RuntimeSample
from .augmentation import AugmentRobustness, TemplaticAugment
from .datahandler.datasource import DataFactory, HuggingFaceDataset
from .modelhandler import LANGCHAIN_HUBS, ModelFactory
@@ -234,7 +233,6 @@ def __init__(

self._testcases = None
self._generated_results = None
self._runtime = RuntimeSample()
self.accuracy_results = None
self.min_pass_dict = None
self.default_min_pass_dict = None
@@ -322,16 +320,14 @@ def generate(self) -> "Harness":
]
else:
self._testcases = {}
self._runtime.transform_time = {}
for k, v in self.model.items():
_ = [
setattr(sample, "expected_results", v(sample.original))
for sample in m_data
]
(
self._testcases[k],
self._runtime.transform_time[k],
) = TestFactory.transform(self.task, self.data, tests, m_data=m_data)
(self._testcases[k]) = TestFactory.transform(
self.task, self.data, tests, m_data=m_data
)

return self

@@ -344,10 +340,7 @@ def generate(self) -> "Harness":
)
if len(tests.keys()) > 2:
tests = {k: v for k, v in tests.items() if k != "bias"}
(
other_testcases,
self._runtime.transform_time,
) = TestFactory.transform(
(other_testcases) = TestFactory.transform(
self.task, self.data, tests, m_data=m_data
)
self._testcases.extend(other_testcases)
@@ -358,13 +351,13 @@
)

else:
self._testcases, self._runtime.transform_time = TestFactory.transform(
self._testcases = TestFactory.transform(
self.task, self.data, tests, m_data=m_data
)

return self

self._testcases, self._runtime.transform_time = TestFactory.transform(
self._testcases = TestFactory.transform(
self.task, self.data, tests, m_data=m_data
)
return self
@@ -381,18 +374,18 @@ def run(self) -> "Harness":
"calling the `.run()` method."
)
if not isinstance(self._testcases, dict):
self._generated_results, self._runtime.run_time = TestFactory.run(
self._generated_results = TestFactory.run(
self._testcases,
self.model,
is_default=self.is_default,
raw_data=self.data,
**self._config.get("model_parameters", {}),
)

else:
self._generated_results = {}
self._runtime.run_time = {}
for k, v in self.model.items():
self._generated_results[k], self._runtime.run_time[k] = TestFactory.run(
self._generated_results[k] = TestFactory.run(
self._testcases[k],
v,
is_default=self.is_default,
@@ -404,16 +397,12 @@

def report(
self,
return_runtime: bool = False,
unit: str = "ms",
format: str = "dataframe",
save_dir: str = None,
) -> pd.DataFrame:
"""Generate a report of the test results.
Args:
return_runtime (bool): whether to return runtime
unit (str): time unit to use
format (str): format in which to save the report
save_dir (str): name of the directory to save the file
Returns:
@@ -458,7 +447,7 @@ def report(
min_pass_rate = self.min_pass_dict.get(
test_type, multiple_perturbations_min_pass_rate
)
if summary[test_type]["category"] == "Accuracy":
if summary[test_type]["category"] in ["Accuracy", "performance"]:
min_pass_rate = 1

report[test_type] = {
@@ -485,10 +474,6 @@ def report(
df_report = df_report.reset_index(drop=True)

self.df_report = df_report.fillna("-")
if return_runtime:
self.df_report[f"time_elapsed ({unit})"] = self.df_report[
"test_type"
].apply(lambda x: self._runtime.total_time(unit)[x])

if format == "dataframe":
return self.df_report
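The category check above (and its twin in the multi-model path below) pins Accuracy and performance tests to a 100% pass requirement, while every other test falls back to its configured or default minimum. A stand-alone sketch that mirrors the hunk rather than quoting the library verbatim (the function name and the "speed" test name are illustrative):

def resolve_min_pass_rate(test_type: str, category: str,
                          min_pass_dict: dict, default_min_pass_rate: float) -> float:
    # Per-test override first, then the configured default.
    min_pass_rate = min_pass_dict.get(test_type, default_min_pass_rate)
    # Accuracy and performance tests are always held to a full pass rate.
    if category in ["Accuracy", "performance"]:
        min_pass_rate = 1
    return min_pass_rate

print(resolve_min_pass_rate("add_typo", "robustness", {"add_typo": 0.70}, 0.65))  # 0.7
print(resolve_min_pass_rate("speed", "performance", {}, 0.65))                    # 1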
@@ -529,7 +514,6 @@ def report(

else:
df_final_report = pd.DataFrame()
time_elapsed = {}
for k, v in self.model.items():
for sample in self._generated_results[k]:
summary[sample.test_type]["category"] = sample.category
@@ -543,7 +527,7 @@ def report(
test_type, self.default_min_pass_dict
)

if summary[test_type]["category"] == "Accuracy":
if summary[test_type]["category"] in ["Accuracy", "performance"]:
min_pass_rate = 1

report[test_type] = {
@@ -566,12 +550,6 @@ def report(
df_report = df_report.reset_index(drop=True)
df_report = df_report.fillna("-")

if return_runtime:
if k not in time_elapsed:
time_elapsed[k] = df_report["model_name"].apply(
lambda x: self._runtime.multi_model_total_time(unit)[x]
)

df_final_report = pd.concat([df_final_report, df_report])

df_final_report["minimum_pass_rate"] = (
@@ -605,16 +583,6 @@ def color_cells(series):
]

styled_df = pivot_df.style.apply(color_cells)
if return_runtime:
time_elapsed_mean = {k: v.mean() for k, v in time_elapsed.items()}
df_time_elapsed = pd.DataFrame(
list(time_elapsed_mean.items()),
columns=["model_name", f"time_elapsed ({unit})"],
)
df_time_elapsed.set_index("model_name", inplace=True)
from IPython.display import display

display(df_time_elapsed)

if format == "dataframe":
return styled_df
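Taken together with the signature change earlier in this file, report() no longer accepts return_runtime or unit; timing results are expected to surface as a regular test category instead. A hedged before/after sketch of calling code — the Harness constructor arguments are assumptions from typical usage, not taken from this PR:

from langtest import Harness

# Constructor arguments are placeholders; check the langtest docs for the exact
# signature in your version.
harness = Harness(task="ner", model="dslim/bert-base-NER", hub="huggingface")

harness.generate()  # TestFactory.transform now returns only the test cases
harness.run()       # TestFactory.run now returns only the generated results

# Before this PR:
#   harness.report(return_runtime=True, unit="ms")
# After this PR, timing shows up as a test category in the report itself:
report_df = harness.report(format="dataframe")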
@@ -663,20 +631,19 @@ def generated_results(self) -> Optional[pd.DataFrame]:

if isinstance(self._generated_results, dict):
generated_results_df = []
if isinstance(self._generated_results, dict):
for k, v in self._generated_results.items():
model_generated_results_df = pd.DataFrame.from_dict(
[x.to_dict() for x in v]
for k, v in self._generated_results.items():
model_generated_results_df = pd.DataFrame.from_dict(
[x.to_dict() for x in v]
)
if (
"test_case" in model_generated_results_df.columns
and "original_question" in model_generated_results_df.columns
):
model_generated_results_df["original_question"].update(
model_generated_results_df.pop("test_case")
)
if (
"test_case" in model_generated_results_df.columns
and "original_question" in model_generated_results_df.columns
):
model_generated_results_df["original_question"].update(
model_generated_results_df.pop("test_case")
)
model_generated_results_df["model_name"] = k
generated_results_df.append(model_generated_results_df)
model_generated_results_df["model_name"] = k
generated_results_df.append(model_generated_results_df)
generated_results_df = pd.concat(generated_results_df).reset_index(drop=True)

else:
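One detail in the generated_results hunk above is worth calling out: the code fills original_question with whatever non-null values exist in test_case by popping that column and calling Series.update. A minimal, self-contained demonstration of the update behavior it relies on (toy data, not from the library):

import pandas as pd

original_question = pd.Series(["what is NER?", None])
test_case = pd.Series([None, "wht is NER?"])

# Series.update overwrites entries in place with the other Series' non-NA
# values and leaves everything else untouched.
original_question.update(test_case)
print(original_question.tolist())  # ['what is NER?', 'wht is NER?']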