Skip to content

Commit

Permalink
Merge pull request #881 from lbushi25/clamp_input_generation
Browse files Browse the repository at this point in the history
Generate clamp inputs according to precondition in the spec
  • Loading branch information
steffenlarsen committed Apr 30, 2024
2 parents 9ea1b1c + 98fc758 commit 7a206c0
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 37 deletions.
100 changes: 63 additions & 37 deletions tests/math_builtin_api/modules/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,57 +77,62 @@
""")
}

def generate_value(base_type, dim):
val = ""
for i in range(dim):
if base_type == "bool":
val += "true,"
if base_type == "float" or base_type == "double" or base_type == "sycl::half":
# 10 digits of precision for floats, doubles and half.
val += str(round(random.uniform(0.1, 0.9), 10))
if base_type == "double":
val += ","
else:
val += "f,"
# random 8 bit integer
if base_type == "char":
val += str(random.randint(0, 127)) + ","
if base_type == "signed char" or base_type == "int8_t":
val += str(random.randint(-128, 127)) + ","
if base_type == "unsigned char" or base_type == "uint8_t":
val += str(random.randint(0, 255)) + ","
def get_literal_suffix(base_type):
mapping = {
"float": "f", "unsigned long": "U", "uint32_t": "U",
"sycl::half": "f", "long long": "LL", "int64_t": "LL",
"unsigned long long": "LLU", "uint64_t": "LLU" }
return mapping[base_type] if base_type in mapping else ""

def generate_literal_value(base_type):
if base_type == "bool":
return "true"
if base_type == "float" or base_type == "double" or base_type == "sycl::half":
# 10 digits of precision for floats, doubles and half.
return round(random.uniform(0.1, 0.9), 10)
# random 8 bit integer
if base_type == "char":
return random.randint(0, 127)
if base_type == "signed char" or base_type == "int8_t":
return random.randint(-128, 127)
if base_type == "unsigned char" or base_type == "uint8_t":
return random.randint(0, 255)
# random 16 bit integer
if base_type == "int" or base_type == "short" or base_type == "int16_t":
val += str(random.randint(-32768, 32767)) + ","
if base_type == "unsigned" or base_type == "unsigned short" or base_type == "uint16_t":
val += str(random.randint(0, 65535)) + ","
if base_type == "int" or base_type == "short" or base_type == "int16_t":
return random.randint(-32768, 32767)
if base_type == "unsigned" or base_type == "unsigned short" or base_type == "uint16_t":
return random.randint(0, 65535)
# random 32 bit integer
if base_type == "long" or base_type == "int32_t":
val += str(random.randint(-2147483648, 2147483647)) + ","
if base_type == "unsigned long" or base_type == "uint32_t":
val += str(random.randint(0, 4294967295)) + "U" + ","
if base_type == "long" or base_type == "int32_t":
return random.randint(-2147483648, 2147483647)
if base_type == "unsigned long" or base_type == "uint32_t":
return random.randint(0, 4294967295)
# random 64 bit integer
if base_type == "long long" or base_type == "int64_t":
val += str(random.randint(-9223372036854775808, 9223372036854775807)) + "LL" + ","
if base_type == "unsigned long long" or base_type == "uint64_t":
val += str(random.randint(0, 18446744073709551615)) + "LLU" + ","
return val[:-1]
if base_type == "long long" or base_type == "int64_t":
return random.randint(-9223372036854775808, 9223372036854775807)
if base_type == "unsigned long long" or base_type == "uint64_t":
return random.randint(0, 18446744073709551615)

def generate_value(base_type, dim):
values = [str(generate_literal_value(base_type)) + get_literal_suffix(base_type) for _ in range(dim)]
return ','.join(values)

def generate_multi_ptr(var_name, var_type, memory, decorated):
decl = ""
value = generate_value(var_type.base_type, var_type.dim)
if memory == "global":
source_name = "multiPtrSourceData"
decl = var_type.name + " " + source_name + "(" + generate_value(var_type.base_type, var_type.dim) + ");\n"
decl = var_type.name + " " + source_name + "(" + value + ");\n"
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::global_space," + decorated + "> "
decl += var_name + "(acc);\n"
if memory == "local":
source_name = "multiPtrSourceData"
decl = var_type.name + " " + source_name + "(" + generate_value(var_type.base_type, var_type.dim) + ");\n"
decl = var_type.name + " " + source_name + "(" + value + ");\n"
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::local_space," + decorated + "> "
decl += var_name + "(acc);\n"
if memory == "private":
source_name = "multiPtrSourceData"
decl = var_type.name + " " + source_name + "(" + generate_value(var_type.base_type, var_type.dim) + ");\n"
decl = var_type.name + " " + source_name + "(" + value + ");\n"
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::private_space," + decorated + "> "
decl += var_name + " = sycl::address_space_cast<sycl::access::address_space::private_space," + decorated + ">(&"
decl += source_name + ");\n"
Expand All @@ -136,6 +141,26 @@ def generate_multi_ptr(var_name, var_type, memory, decorated):
def generate_variable(var_name, var_type, var_index):
return var_type.name + " " + var_name + "(" + generate_value(var_type.base_type, var_type.dim) + ");\n"

# argument generator for clamp which makes sure that its third argument is at least equal to its second argument in every dimension.
def generate_arguments_clamp(sig):
arg_types = sig.arg_types
arg_names = ["inputData_" + str(i) for i in range(3)]
arg0 = [str(generate_literal_value(arg_types[0].base_type)) + get_literal_suffix(arg_types[0].base_type) for _ in range(arg_types[0].dim)]
arg1 = [generate_literal_value(arg_types[1].base_type) for _ in range(arg_types[1].dim)]
arg2 = [generate_literal_value(arg_types[2].base_type) for _ in range(arg_types[2].dim)]

# clamp requires that minval (arg1) <= maxval (arg2)
for i in range(arg_types[1].dim):
if arg1[i] > arg2[i]:
arg1[i], arg2[i] = arg2[i], arg1[i] # swap

arg1 = [str(x) + get_literal_suffix(arg_types[1].base_type) for x in arg1]
arg2 = [str(x) + get_literal_suffix(arg_types[2].base_type) for x in arg2]
arg_vals = [arg0, arg1, arg2]
args = [arg_types[i].name + " " + arg_names[i] + "(" + ",".join(arg_vals[i]) + ");\n" for i in range(3)]
return (arg_names, " ".join(args))


def generate_arguments(sig, memory, decorated):
arg_src = ""
arg_names = []
Expand All @@ -156,7 +181,6 @@ def generate_arguments(sig, memory, decorated):
current_arg = generate_multi_ptr(arg_name, arg, memory, decorated )
else:
current_arg = generate_variable(arg_name, arg, arg_index)

arg_src += current_arg + " "
arg_index += 1
return (arg_names, arg_src)
Expand Down Expand Up @@ -283,7 +307,9 @@ def generate_reference_ptr(types, sig, arg_names, arg_src):
def generate_test_case(test_id, types, sig, memory, check, decorated = ""):
testCaseSource = test_case_templates_check[memory] if check else test_case_templates[memory]
testCaseId = str(test_id)
(arg_names, arg_src) = generate_arguments(sig, memory, decorated)
# for the clamp function we use a separate argument generator to make sure that its preconditions are met,
# otherwise argument generation for clamp would be completely random.
(arg_names, arg_src) = generate_arguments(sig, memory, decorated) if sig.name != "clamp" else generate_arguments_clamp(sig)
testCaseSource = testCaseSource.replace("$REFERENCE", generate_reference(sig, arg_names, arg_src))
testCaseSource = testCaseSource.replace("$PTR_REF", generate_reference_ptr(types, sig, arg_names, arg_src))
testCaseSource = testCaseSource.replace("$TEST_ID", testCaseId)
Expand Down
1 change: 1 addition & 0 deletions util/math_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ run_func_on_vector_result_ref(funT fun, Args... args) {
sycl::vec<T, N> res;
std::map<int, bool> undefined;
for (int i = 0; i < N; i++) {
undefined[i] = false;
resultRef<T> element = fun(getElement(args, i)...);
if (element.undefined.empty())
setElement<T, N>(res, i, element.res);
Expand Down

0 comments on commit 7a206c0

Please sign in to comment.