-
Notifications
You must be signed in to change notification settings - Fork 196
/
test_main.py
116 lines (95 loc) Β· 4.3 KB
/
test_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pytest
from initialize import initialize_models
from interfaces.QuestionAnswerOperation import QuestionAnswerOperation
from interfaces.SentenceOperation import (
SentenceAndTargetOperation,
SentenceOperation,
)
from interfaces.TaggingOperation import TaggingOperation
from TestRunner import OperationRuns
def get_assert_message(transformation, expected_output, predicted_output):
    """Build the failure message for an expected/predicted output mismatch.

    The message names the transformation's class, then shows the expected
    and predicted values on separate lines.
    """
    name = transformation.__class__.__name__
    lines = [
        f"Mis-match in expected and predicted output for {name} transformation: ",
        f"Expected Output: {expected_output} ",
        f"Predicted Output: {predicted_output}",
    ]
    return "\n ".join(lines)
def execute_sentence_operation_test_case(transformation, test):
    """Check a SentenceOperation's ``generate`` output against one test case.

    ``test["inputs"]`` is passed to ``transformation.generate`` as keyword
    arguments. Each returned perturbation is compared with the ``"sentence"``
    field of the entry at the same position in ``test["outputs"]``.

    Perturbations beyond the number of expected outputs are not checked.
    Producing fewer perturbations than expected outputs fails the test,
    because otherwise the unmatched expectations would be silently skipped.
    """
    outputs = test["outputs"]
    # Materialize so the count can be checked even if generate() is lazy.
    perturbs = list(transformation.generate(**test["inputs"]))
    # zip() stops at the shorter sequence; guard against a vacuous pass when
    # the transformation returns too few (e.g. zero) perturbations.
    assert len(perturbs) >= len(outputs), (
        f"{transformation.__class__.__name__} produced {len(perturbs)} "
        f"perturbation(s) but the test case expects {len(outputs)}"
    )
    for pred_output, output in zip(perturbs, outputs):
        assert pred_output == output["sentence"], get_assert_message(
            transformation, output["sentence"], pred_output
        )
def execute_sentence_target_operation_test_case(transformation, test):
    """Check a SentenceAndTargetOperation's ``generate`` output against a test case.

    ``test["inputs"]`` is passed to ``transformation.generate`` as keyword
    arguments. Every returned ``(sentence, target)`` pair is compared with the
    ``"sentence"`` and ``"target"`` fields of the entry at the same position
    in ``test["outputs"]``. The number of pairs must equal the number of
    expected outputs.
    """
    outputs = test["outputs"]
    # Materialize so the count can be checked even if generate() is lazy.
    perturbs = list(transformation.generate(**test["inputs"]))
    # Without this check, too few perturbations would leave expected outputs
    # unchecked (vacuous pass) and too many would crash with a bare IndexError.
    assert len(perturbs) == len(outputs), (
        f"{transformation.__class__.__name__} produced {len(perturbs)} "
        f"perturbation(s) but the test case expects {len(outputs)}"
    )
    for (sentence, target), expected in zip(perturbs, outputs):
        assert sentence == expected["sentence"], get_assert_message(
            transformation, expected["sentence"], sentence
        )
        assert target == expected["target"], get_assert_message(
            transformation, expected["target"], target
        )
def execute_ques_ans_test_case(transformation, test):
    """Check a QuestionAnswerOperation's ``generate`` output against a test case.

    ``test["inputs"]`` is passed to ``transformation.generate`` as keyword
    arguments. Every returned ``(context, question, answers)`` triple is
    compared field-by-field with the entry at the same position in
    ``test["outputs"]``. The number of triples must equal the number of
    expected outputs.
    """
    outputs = test["outputs"]
    # Materialize so the count can be checked even if generate() is lazy.
    perturbs = list(transformation.generate(**test["inputs"]))
    # Without this check, too few perturbations would leave expected outputs
    # unchecked (vacuous pass) and too many would crash with a bare IndexError.
    assert len(perturbs) == len(outputs), (
        f"{transformation.__class__.__name__} produced {len(perturbs)} "
        f"perturbation(s) but the test case expects {len(outputs)}"
    )
    for (context, question, answers), expected in zip(perturbs, outputs):
        assert context == expected["context"], get_assert_message(
            transformation, expected["context"], context
        )
        assert question == expected["question"], get_assert_message(
            transformation, expected["question"], question
        )
        assert answers == expected["answers"], get_assert_message(
            transformation, expected["answers"], answers
        )
def execute_tagging_test_case(transformation, test):
    """Check a TaggingOperation's ``generate`` output against one test case.

    The whitespace-separated ``token_sequence`` and ``tag_sequence`` strings
    in ``test["inputs"]`` are split into lists and passed positionally to
    ``transformation.generate``. Every returned ``(tokens, tags)`` pair is
    compared with the split ``token_sequence`` / ``tag_sequence`` of the
    entry at the same position in ``test["outputs"]``. The number of pairs
    must equal the number of expected outputs.
    """
    inputs = test["inputs"]
    outputs = test["outputs"]
    # Materialize so the count can be checked even if generate() is lazy.
    perturbs = list(
        transformation.generate(
            inputs["token_sequence"].split(), inputs["tag_sequence"].split()
        )
    )
    # Without this check, too few perturbations would leave expected outputs
    # unchecked (vacuous pass) and too many would crash with a bare IndexError.
    assert len(perturbs) == len(outputs), (
        f"{transformation.__class__.__name__} produced {len(perturbs)} "
        f"perturbation(s) but the test case expects {len(outputs)}"
    )
    for (p_tokens, p_tags), expected in zip(perturbs, outputs):
        expected_tokens = expected["token_sequence"].split()
        expected_tags = expected["tag_sequence"].split()
        assert p_tokens == expected_tokens, get_assert_message(
            transformation, expected_tokens, p_tokens
        )
        assert p_tags == expected_tags, get_assert_message(
            transformation, expected_tags, p_tags
        )
def execute_test_case_for_transformation(transformation_name):
    """Run the test cases paired with each operation of a transformation.

    ``OperationRuns(transformation_name)`` supplies the operations and their
    test cases, which are paired positionally. Each operation is routed to the
    checker for the first interface it is an instance of. Operations matching
    none of the known interfaces are reported on stdout and otherwise skipped.
    """
    runs = OperationRuns(transformation_name)
    # Checked in order: an operation matching several interfaces uses the first.
    checkers = (
        (SentenceOperation, execute_sentence_operation_test_case),
        (SentenceAndTargetOperation, execute_sentence_target_operation_test_case),
        (QuestionAnswerOperation, execute_ques_ans_test_case),
        (TaggingOperation, execute_tagging_test_case),
    )
    for operation, case in zip(runs.operations, runs.operation_test_cases):
        checker = next(
            (run for iface, run in checkers if isinstance(operation, iface)),
            None,
        )
        if checker is None:
            print(f"Invalid transformation type: {operation}")
        else:
            checker(operation, case)
def execute_test_case_for_filter(filter_name):
    """Run the test cases paired with each operation of a filter.

    ``OperationRuns(filter_name, "filters")`` supplies the filter instances
    and their test cases, which are paired positionally. Each case's
    ``"inputs"`` are passed to the instance's ``filter`` method as keyword
    arguments, and the result must equal the case's ``"outputs"``.
    """
    runs = OperationRuns(filter_name, "filters")
    # Named ``filter_op`` rather than ``filter`` to avoid shadowing the builtin.
    for filter_op, test in zip(runs.operations, runs.operation_test_cases):
        output = filter_op.filter(**test["inputs"])
        assert (
            output == test["outputs"]
        ), f"The filter should return {test['outputs']}"
def test_operation(transformation_name, filter_name):
    """Pytest entry point: check one transformation and one filter.

    ``transformation_name`` and ``filter_name`` are injected by pytest. They
    are not defined in this file, so presumably they come from a conftest
    fixture or parametrization (confirm). ``initialize_models()`` is called
    before either set of test cases runs.
    """
    initialize_models()
    # Transformation cases run first, then filter cases, as in the original.
    suites = (
        (execute_test_case_for_transformation, transformation_name),
        (execute_test_case_for_filter, filter_name),
    )
    for run_suite, operation_name in suites:
        run_suite(operation_name)
def main():
    """Run pytest programmatically and return its exit code.

    Returning the code, rather than discarding it, lets the script's process
    exit status reflect test failures.
    """
    return pytest.main()


if __name__ == "__main__":
    # pytest.main() does not exit by itself; propagate its exit code so
    # shells and CI see failures instead of a constant status 0.
    raise SystemExit(main())