Skip to content

Commit

Permalink
Make math_builtin_api tests splittable into smaller chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
jzc committed Apr 30, 2024
1 parent 7a206c0 commit 56b62ab
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
build/
build*/
.idea
log.txt
*.pyc
45 changes: 28 additions & 17 deletions tests/math_builtin_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,44 @@ set(math_builtin_depends
"modules/test_generator.py"
)

set(SYCL_CTS_MATH_BUILTIN_N_SPLITS 16 CACHE STRING
"The number of times to divide each math_builtin_api test")

foreach(cat ${MATH_CAT_WITH_VARIANT})
foreach(var ${MATH_VARIANT})
if ("${cat}" STREQUAL geometric AND "${var}" STREQUAL half)
continue()
endif()
foreach(split_index RANGE 1 ${SYCL_CTS_MATH_BUILTIN_N_SPLITS})
if ("${cat}" STREQUAL geometric AND "${var}" STREQUAL half)
continue()
endif()
# Invoke our generator
# the path to the generated cpp file will be added to TEST_CASES_LIST
math(EXPR split_index_0 "${split_index} - 1")
generate_cts_test(TESTS TEST_CASES_LIST
GENERATOR "generate_math_builtin.py"
OUTPUT "math_builtin_${cat}_${var}_${split_index}_${SYCL_CTS_MATH_BUILTIN_N_SPLITS}.cpp"
INPUT "math_builtin.template"
EXTRA_ARGS -test ${cat} -variante ${var} -marray true
-i ${split_index_0} -n ${SYCL_CTS_MATH_BUILTIN_N_SPLITS}
DEPENDS ${math_builtin_depends}
)
endforeach()
endforeach()
endforeach()

foreach(cat ${MATH_CAT})
foreach(split_index RANGE 1 ${SYCL_CTS_MATH_BUILTIN_N_SPLITS})
# Invoke our generator
# the path to the generated cpp file will be added to TEST_CASES_LIST
math(EXPR split_index_0 "${split_index} - 1")
generate_cts_test(TESTS TEST_CASES_LIST
GENERATOR "generate_math_builtin.py"
OUTPUT "math_builtin_${cat}_${var}.cpp"
OUTPUT "math_builtin_${cat}_${split_index}_${SYCL_CTS_MATH_BUILTIN_N_SPLITS}.cpp"
INPUT "math_builtin.template"
EXTRA_ARGS -test ${cat} -variante ${var} -marray true
EXTRA_ARGS -test ${cat} -marray true
-i ${split_index_0} -n ${SYCL_CTS_MATH_BUILTIN_N_SPLITS}
DEPENDS ${math_builtin_depends}
)
endforeach()
endforeach()

foreach(cat ${MATH_CAT})
# Invoke our generator
# the path to the generated cpp file will be added to TEST_CASES_LIST
generate_cts_test(TESTS TEST_CASES_LIST
GENERATOR "generate_math_builtin.py"
OUTPUT "math_builtin_${cat}.cpp"
INPUT "math_builtin.template"
EXTRA_ARGS -test ${cat} -marray true
DEPENDS ${math_builtin_depends}
)
endforeach()

add_cts_test(${TEST_CASES_LIST})
61 changes: 50 additions & 11 deletions tests/math_builtin_api/generate_math_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,18 @@ def write_cases_to_file(generated_test_cases, inputFile, outputFile, extension=N
with open(outputFile, 'w+') as output:
output.write(newSource)

def create_tests(test_id, types, signatures, kind, template, file_name, check = False):
def get_split(lst, split_index, n_splits):
if split_index is None:
return lst
split_size = len(lst) // n_splits
if split_index + 1 == n_splits:
split = lst[split_size * split_index:]
else:
split = lst[split_size * split_index:split_size * (split_index+1)]
assert len(split) >= 1, 'Bad split!'
return split

def create_tests(test_id, types, signatures, kind, template, file_name, split_index, n_splits, check = False):
expanded_signatures = test_generator.expand_signatures(types, signatures)

# Extensions should be placed on separate files.
Expand All @@ -92,13 +103,16 @@ def create_tests(test_id, types, signatures, kind, template, file_name, check =

if base_signatures and kind == 'base':
generated_base_test_cases = test_generator.generate_test_cases(test_id, types, base_signatures, check)
write_cases_to_file(generated_base_test_cases, template, file_name)
test_cases = get_split(generated_base_test_cases, split_index, n_splits)
write_cases_to_file("".join(test_cases), template, file_name)
elif half_signatures and kind == 'half':
generated_half_test_cases = test_generator.generate_test_cases(test_id + 300000, types, half_signatures, check)
write_cases_to_file(generated_half_test_cases, template, file_name, "fp16")
test_cases = get_split(generated_half_test_cases, split_index, n_splits)
write_cases_to_file("".join(test_cases), template, file_name, "fp16")
elif double_signatures and kind == 'double':
generated_double_test_cases = test_generator.generate_test_cases(test_id + 600000, types, double_signatures, check)
write_cases_to_file(generated_double_test_cases, template, file_name, "fp64")
test_cases = get_split(generated_double_test_cases, split_index, n_splits)
write_cases_to_file("".join(test_cases), template, file_name, "fp64")
else:
print("No %s overloads to generate for the test category" % kind)
sys.exit(1)
Expand Down Expand Up @@ -132,8 +146,26 @@ def main():
required=True,
metavar='<out file>',
help='CTS test output')
argparser.add_argument(
'-i',
dest='split_index',
help='Specifies which split to generate when splitting',
type=int
)
argparser.add_argument(
'-n',
dest='n_splits',
help='Splits the generated test cases into n different files',
type=int
)
args = argparser.parse_args()

if args.n_splits is not None or args.split_index is not None:
assert args.n_splits is not None and args.split_index is not None, \
"both -i and -n are necessary if one is specified"
assert 0 <= args.split_index < args.n_splits, \
"the value passed to -i is out of bounds"

use_marray = (args.marray == 'true')
run = runner(use_marray)
if not use_marray:
Expand All @@ -147,31 +179,38 @@ def main():

if args.test == 'integer':
integer_signatures = sycl_functions.create_integer_signatures()
create_tests(0, expanded_types, integer_signatures, args.variante, args.template, args.output, verifyResults)
create_tests(0, expanded_types, integer_signatures, args.variante, args.template, args.output,
args.split_index, args.n_splits, verifyResults)

if args.test == 'common':
common_signatures = sycl_functions.create_common_signatures()
create_tests(1000000, expanded_types, common_signatures, args.variante, args.template, args.output, verifyResults)
create_tests(1000000, expanded_types, common_signatures, args.variante, args.template, args.output,
args.split_index, args.n_splits, verifyResults)

if args.test == 'geometric':
geomteric_signatures = sycl_functions.create_geometric_signatures()
create_tests(2000000, expanded_types, geomteric_signatures, args.variante, args.template, args.output, verifyResults)
create_tests(2000000, expanded_types, geomteric_signatures, args.variante, args.template, args.output,
args.split_index, args.n_splits, verifyResults)

if args.test == 'relational':
relational_signatures = sycl_functions.create_relational_signatures()
create_tests(3000000, expanded_types, relational_signatures, args.variante, args.template, args.output, verifyResults)
create_tests(3000000, expanded_types, relational_signatures, args.variante, args.template, args.output,
args.split_index, args.n_splits, verifyResults)

if args.test == 'float':
float_signatures = sycl_functions.create_float_signatures()
create_tests(4000000, expanded_types, float_signatures, args.variante, args.template, args.output, verifyResults)
create_tests(4000000, expanded_types, float_signatures, args.variante, args.template, args.output,
args.split_index, args.n_splits, verifyResults)

if args.test == 'native':
native_signatures = sycl_functions.create_native_signatures()
create_tests(5000000, expanded_types, native_signatures, args.variante, args.template, args.output, verifyResults)
create_tests(5000000, expanded_types, native_signatures, args.variante, args.template, args.output,
args.split_index, args.n_splits, verifyResults)

if args.test == 'half':
half_signatures = sycl_functions.create_half_signatures()
create_tests(6000000, expanded_types, half_signatures, args.variante, args.template, args.output, verifyResults)
create_tests(6000000, expanded_types, half_signatures, args.variante, args.template, args.output,
args.split_index, args.n_splits, verifyResults)

if __name__ == "__main__":
main()
20 changes: 10 additions & 10 deletions tests/math_builtin_api/modules/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,31 +347,31 @@ def generate_test_case(test_id, types, sig, memory, check, decorated = ""):

def generate_test_cases(test_id, types, sig_list, check):
random.seed(0)
test_source = ""
test_cases = []
decorated_yes = "sycl::access::decorated::yes"
decorated_no = "sycl::access::decorated::no"
for sig in sig_list:
if sig.pntr_indx:#If the signature contains a pointer argument.
test_source += generate_test_case(test_id, types, sig, "private", check, decorated_no)
test_cases.append(generate_test_case(test_id, types, sig, "private", check, decorated_no))
test_id += 1
test_source += generate_test_case(test_id, types, sig, "private", check, decorated_yes)
test_cases.append(generate_test_case(test_id, types, sig, "private", check, decorated_yes))
test_id += 1
test_source += generate_test_case(test_id, types, sig, "local", check, decorated_no)
test_cases.append(generate_test_case(test_id, types, sig, "local", check, decorated_no))
test_id += 1
test_source += generate_test_case(test_id, types, sig, "local", check, decorated_yes)
test_cases.append(generate_test_case(test_id, types, sig, "local", check, decorated_yes))
test_id += 1
test_source += generate_test_case(test_id, types, sig, "global", check, decorated_no)
test_cases.append(generate_test_case(test_id, types, sig, "global", check, decorated_no))
test_id += 1
test_source += generate_test_case(test_id, types, sig, "global", check, decorated_yes)
test_cases.append(generate_test_case(test_id, types, sig, "global", check, decorated_yes))
test_id += 1
else:
if check:
test_source += generate_test_case(test_id, types, sig, "no_ptr", check)
test_cases.append(generate_test_case(test_id, types, sig, "no_ptr", check))
test_id += 1
else:
test_source += generate_test_case(test_id, types, sig, "private", check)
test_cases.append(generate_test_case(test_id, types, sig, "private", check))
test_id += 1
return test_source
return test_cases

# Lists of the types with equal sizes
chars = ["char", "signed char", "unsigned char"]
Expand Down

0 comments on commit 56b62ab

Please sign in to comment.