The result sampling dict is always {9: 12, 17: 11, 8: 16, 2: 14, -1: 10, 6: 15, 5: 11, -2: 11}
def test_sample_compute_location():
n = 100
sch = tir.Schedule(tiled_conv2d_with_padding, seed=42, debug_mask="all")
pad_input = sch.get_block("PadInput")
decision_dict = dict()
for _ in range(n):
_ = sch.sample_compute_location(pad_input) # pylint: disable=invalid-name
decision = sch.trace.decisions[sch.trace.insts[-1]]
decision_dict[decision] = decision_dict[decision] + 1 if decision in decision_dict else 1
n_candidates = 8
expected_rate = 1.0 / n_candidates
for _, cnt in decision_dict.items():
assert (expected_rate - 0.03) * n <= cnt <= (expected_rate + 0.03) * n
E assert 16 <= ((0.125 + 0.03) * 100)
tests/python/unittest/test_tir_schedule_sampling.py:193: AssertionError
============================================================================================== short test summary info ===============================================================================================
FAILED tests/python/unittest/test_tir_schedule_sampling.py::test_sample_compute_location - assert 16 <= ((0.125 + 0.03) * 100)