@@ -0,0 +1,49 @@
"""Test code for yolo op"""
import logging
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple

def verify_yolo(ishape, n, classes):
'''Verify yolo operator by comparing outputs from tvm and numpy implementation'''

A = tvm.placeholder(ishape, name='A')
B = topi.cpp.yolo.yolo(A, n, classes)
dtype = A.dtype

def get_ref_data_yolo():
'''Randomly initialize the data variables and get refernce output for the yolo operation'''
a_np = np.random.uniform(size=ishape).astype(dtype)
b_np = topi.testing.yolo_python(a_np, n, classes)
return a_np, b_np

a_np, b_np = get_ref_data_yolo()
def check_device(device):
'''Check the device is available and if so, build and run the program'''
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.default_schedule(target, [B], False)
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, B], device, name="yolo")
func(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm', 'vulkan']:
check_device(device)

def test_yolo():
verify_yolo((1, 425, 19, 19), 5, 80)

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
test_yolo()