diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index dbf00ebeb52b..b8740f811ff7 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -15,6 +15,8 @@ from .pooling import schedule_pool, schedule_global_pool from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw from .extern import schedule_extern -from .vision import schedule_region -from .vision import schedule_reorg from .nn import schedule_lrn, schedule_l2_normalize +from .vision import * +from . import ssd +from .ssd import * +from .nms import * diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py new file mode 100644 index 000000000000..4d4e402de5c2 --- /dev/null +++ b/topi/python/topi/cuda/nms.py @@ -0,0 +1,354 @@ +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison +"""Non-maximum suppression operator""" +import math +import tvm + +from tvm import api +from topi.vision import nms + + +def sort_ir(data, index, output, axis, is_descend): + """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. + + Parameters + ---------- + data: Buffer + 2D Buffer of input boxes' score with shape [batch_size, num_anchors]. + + index : Buffer + Buffer of number of valid number of boxes. + + output : Buffer + Output buffer of indicies of sorted tensor. + + axis : int + The axis used for sorting. + + is_descend : bool + If the sorted data is in descending order. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + + max_threads = int( + tvm.target.current_target(allow_none=False).max_num_threads) + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib = tvm.ir_builder.create() + p_data = ib.buffer_ptr(data) + p_index = ib.buffer_ptr(index) + p_out = ib.buffer_ptr(output) + ndim = len(data.shape) + assert data.dtype == "float32", "Currently only supports input dtype to be float32" + assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim + + axis_mul_before = 1 + axis_mul_after = 1 + if axis < 0: + axis = ndim + axis + for i in range(0, ndim): + if i < axis: + axis_mul_before *= data.shape[i] + elif i > axis: + axis_mul_after *= data.shape[i] + + dshape = 0 + for i in range(0, len(index.shape)): + dshape += index.shape[i] + dshape = tvm.select(dshape > axis_mul_before*axis_mul_after, dshape, + axis_mul_before*axis_mul_after) + + sizes_temp = ib.allocate( + "int32", dshape, name="sizes_temp", scope="global") + sizes = ib.allocate("int32", dshape, name="sizes", scope="global") + temp_index = ib.allocate("int32", dshape, name="temp_index", scope="local") + temp_data = ib.allocate("float32", dshape, name="temp_data", scope="local") + data_new = ib.allocate("float32", dshape, name="data_new", scope="global") + index_new = ib.allocate("int32", dshape, name="index_new", scope="global") + nthread_tx = max_threads + nthread_bx = dshape // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + + with ib.if_scope(tid < axis_mul_before * axis_mul_after): + sizes[tid] = p_index[tid] + sizes_temp[tid] = p_index[tid] + + with ib.if_scope(tid < axis_mul_before * axis_mul_after): + with ib.for_range(0, tvm.floor(tvm.sqrt((axis_mul_before * axis_mul_after) \ + .astype("float32"))) + 1, name="k") as k: + with ib.if_scope(tid - (tvm.const(1, "int32") << k) >= 0): + with ib.if_scope(k % 2 == 0): + sizes[tid] += sizes_temp[tid - ( + tvm.const(1, "int32") << k)] + sizes_temp[tid] = sizes[tid] + with ib.else_scope(): + sizes_temp[tid] += sizes[tid - ( + tvm.const(1, "int32") << k)] + sizes[tid] = sizes_temp[tid] + + with ib.if_scope(tid < axis_mul_before * axis_mul_after): + i = tid / axis_mul_after + j = tid % axis_mul_after + current_sort_num = p_index[tid] + base_idx = i * data.shape[axis] * axis_mul_after + j + with ib.for_range(0, current_sort_num, name="k") as k: + full_idx = base_idx + k * axis_mul_after + with ib.if_scope(tid == 0): + start = 0 + with ib.else_scope(): + start = sizes[tid-1] + index_new[start + k] = k + data_new[start + k] = p_data[full_idx] + + with ib.if_scope(tid < axis_mul_before * axis_mul_after): + with ib.if_scope(tid == 0): + start = 0 + with ib.else_scope(): + start = sizes[tid-1] + # OddEvenTransposeSort + with ib.for_range(0, p_index[tid], name="k") as k: + with ib.for_range(0, p_index[tid] - 1, name="i") as i: + with ib.if_scope(i % 2 == (k & 1)): + with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) ^ + is_descend) == False): + temp_data[tid] = data_new[i+start] + data_new[i+start] = data_new[i+start+1] + data_new[i+start+1] = temp_data[tid] + temp_index[tid] = index_new[i+start] + index_new[i+start] = index_new[i+start+1] + index_new[i+start+1] = temp_index[tid] + + with ib.if_scope(tid < axis_mul_before * axis_mul_after): + i = tid / axis_mul_after + j = tid % axis_mul_after + current_sort_num = p_index[tid] + base_idx = i * data.shape[axis] * axis_mul_after + j + with ib.for_range(0, data.shape[axis], name="k") as k: + with ib.if_scope(tid == 0): + start = 0 + with ib.else_scope(): + start = sizes[tid-1] + p_out[base_idx + k * axis_mul_after] = tvm.select( + k < current_sort_num, + index_new[k+start], k) + body = ib.get() + return body + + +def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk): + """Low level IR routing for transform location in multibox_detection operator. + + Parameters + ---------- + data: Buffer + Buffer of output boxes with class and score. + + sort_result : Buffer + Buffer of output box indexes sorted by score. + + valid_count : Buffer + Buffer of number of valid output boxes. + + out : Buffer + Output buffer. + + nms_threshold : float + Non-maximum suppression threshold. + + force_suppress : boolean + Whether to suppress all detections regardless of class_id. + + nms_topk : int + Keep maximum top k detections before nms, -1 for no limit. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + def calculate_overlap(out_tensor, box_a_idx, box_b_idx): + """Calculate overlap of two boxes. + """ + w = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2]) + - tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx])) + h = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3]) + - tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1])) + i = w * h + u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \ + (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \ + (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \ + (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i + return tvm.select(u <= 0.0, 0.0, i / u) + + max_threads = int(math.sqrt( + tvm.target.current_target(allow_none=False).max_num_threads)) + tx = tvm.thread_axis("threadIdx.x") + ty = tvm.thread_axis("threadIdx.y") + bx = tvm.thread_axis("blockIdx.x") + by = tvm.thread_axis("blockIdx.y") + ib = tvm.ir_builder.create() + p_data = ib.buffer_ptr(data) + p_sort_result = ib.buffer_ptr(sort_result) + p_valid_count = ib.buffer_ptr(valid_count) + p_out = ib.buffer_ptr(out) + batch_size = out.shape[0] + num_anchors = out.shape[1] + nthread_tx = max_threads + nthread_bx = num_anchors // max_threads + 1 + nthread_ty = max_threads + nthread_by = 6 // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(ty, "thread_extent", nthread_ty) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + i = bx * max_threads + tx + j = by * max_threads + ty + + nms_threshold_node = tvm.make.node( + "FloatImm", dtype="float32", value=nms_threshold) + nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk) + force_suppress_node = tvm.make.node( + "IntImm", dtype="int32", value=1 if force_suppress else 0) + with ib.for_range(0, batch_size, for_type="unroll", name="n") as n: + with ib.if_scope( + tvm.all(nms_threshold_node > 0, nms_threshold_node < 1, + p_valid_count[0] > 0)): + # Reorder output + nkeep = tvm.select( + tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]), + nms_topk, p_valid_count[n]) + with ib.if_scope(i < nkeep): + with ib.if_scope(j < 6): + p_out[(n * num_anchors * 6 + + i * 6 + j)] = p_data[(n * num_anchors * 6 + + p_sort_result[n * num_anchors + i] * 6 + j)] + with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])): + with ib.if_scope(i < p_valid_count[n] - nkeep): + with ib.if_scope(j < 6): + p_out[(n * num_anchors * 6 + + (i + nkeep) * 6 + j)] = p_data[(n * num_anchors * 6 + + (i + nkeep) * 6 + j)] + # Apply nms + with ib.if_scope(i < p_valid_count[n]): + offset_i = i * 6 + with ib.if_scope(p_out[n * num_anchors * 6 + offset_i] >= 0): + with ib.if_scope(j < p_valid_count[n]): + offset_j = j * 6 + with ib.if_scope(tvm.all(j > i, p_out[n * num_anchors * 6 + + offset_j] >= 0)): + with ib.if_scope(tvm.any(force_suppress_node > 0, + p_out[n * num_anchors * 6 + offset_i] == + p_out[n * num_anchors * 6 + offset_j])): + # When force_suppress == True or class_id equals + iou = calculate_overlap( + p_out, n * num_anchors * 6 + offset_i + 2, + n * num_anchors * 6 + offset_j + 2) + with ib.if_scope(iou >= nms_threshold): + p_out[ + n * num_anchors * 6 + offset_j] = -1.0 + with ib.else_scope(): + with ib.if_scope(i < p_valid_count[n]): + with ib.if_scope(j < 6): + p_out[(n * num_anchors * 6 + + i * 6 + j)] = p_data[n * num_anchors * 6 + i * 6 + j] + # Set invalid entry to be -1 + with ib.if_scope(i < num_anchors - p_valid_count[n]): + with ib.if_scope(j < 6): + p_out[n * num_anchors * 6 + (i + + p_valid_count[n]) * 6 + j] = -1.0 + body = ib.get() + return body + + +@nms.register(["cuda", "gpu"]) +def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1): + """Non-maximum suppression operator for object detection. + + Parameters + ---------- + data: tvm.Tensor + 3-D tensor with shape [batch_size, num_anchors, 6]. + The last dimension should be in format of + [class_id, score, box_left, box_top, box_right, box_bottom]. + + valid_count : tvm.Tensor + 1-D tensor for valid number of boxes. + + nms_threshold : float + Non-maximum suppression threshold. + + force_suppress : boolean + Whether to suppress all detections regardless of class_id. + + nms_topk : int + Keep maximum top k detections before nms, -1 for no limit. + + Returns + ------- + out : tvm.Tensor + 3-D tensor with shape [batch_size, num_anchors, 6]. + + Example + -------- + .. code-block:: python + + # An example to use nms + dshape = (1, 5, 6) + data = tvm.placeholder(dshape, name="data") + valid_count = tvm.placeholder( + (dshape[0],), dtype="int32", name="valid_count") + nms_threshold = 0.7 + force_suppress = True + nms_topk = -1 + out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) + np_data = np.random.uniform(dshape) + np_valid_count = np.array([4]) + s = topi.generic.schedule_nms(out) + f = tvm.build(s, [data, valid_count, out], "llvm") + ctx = tvm.cpu() + tvm_data = tvm.nd.array(np_data, ctx) + tvm_valid_count = tvm.nd.array(np_valid_count, ctx) + tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) + f(tvm_data, tvm_valid_count, tvm_out) + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + valid_count_dtype = "int32" + valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype, + "valid_count_buf", data_alignment=4) + data_buf = api.decl_buffer( + data.shape, data.dtype, "data_buf", data_alignment=8) + score_axis = 1 + score_shape = (batch_size, num_anchors) + score_tensor = tvm.compute( + score_shape, lambda i, j: data[i, j, score_axis], name="score_tensor") + score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype, + "score_tensor_buf", data_alignment=8) + sort_tensor_dtype = "int32" + sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype, + "sort_tensor_buf", data_alignment=8) + + sort_tensor = \ + tvm.extern(score_shape, + [score_tensor, valid_count], + lambda ins, outs: sort_ir( + ins[0], ins[1], outs[0], score_axis, True), + dtype=sort_tensor_dtype, + in_buffers=[score_tensor_buf, valid_count_buf], + out_buffers=sort_tensor_buf, + name="nms_sort") + out = \ + tvm.extern(data.shape, + [data, sort_tensor, valid_count], + lambda ins, outs: nms_ir( + ins[0], ins[1], ins[2], outs[0], nms_threshold, + force_suppress, nms_topk), + dtype="float32", + in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], + tag="nms") + return out diff --git a/topi/python/topi/cuda/ssd/__init__.py b/topi/python/topi/cuda/ssd/__init__.py new file mode 100644 index 000000000000..d680c578e7aa --- /dev/null +++ b/topi/python/topi/cuda/ssd/__init__.py @@ -0,0 +1,5 @@ +# pylint: disable=wildcard-import +"""VISION network operators""" +from __future__ import absolute_import as _abs + +from .multibox import * diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py new file mode 100644 index 000000000000..c22e7a513d7d --- /dev/null +++ b/topi/python/topi/cuda/ssd/multibox.py @@ -0,0 +1,360 @@ +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements +"""SSD multibox operators""" +from __future__ import absolute_import as _abs +import math +import tvm + +from tvm import api + +import topi + +from topi.vision.ssd import multibox_prior +from topi.vision.ssd import multibox_detection +from topi.vision.ssd import multibox_transform_loc +from ..nms import nms + +def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): + """Low level IR routing for multibox_prior operator. + + Parameters + ---------- + data : Buffer + Input data buffer. + + out : Buffer + Output buffer. + + sizes : tuple of float + Tuple of sizes for anchor boxes. + + ratios : tuple of float + Tuple of ratios for anchor boxes. + + steps : Tuple of float + Priorbox step across y and x, -1 for auto calculation. + + offsets : tuple of int + Priorbox center offsets, y and x respectively. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads)) + tx = tvm.thread_axis("threadIdx.x") + ty = tvm.thread_axis("threadIdx.y") + bx = tvm.thread_axis("blockIdx.x") + by = tvm.thread_axis("blockIdx.y") + ib = tvm.ir_builder.create() + p_out = ib.buffer_ptr(out) + in_height = data.shape[2] + in_width = data.shape[3] + nthread_tx = max_threads + nthread_bx = in_height // max_threads + 1 + nthread_ty = max_threads + nthread_by = in_width // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(ty, "thread_extent", nthread_ty) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + + num_sizes = len(sizes) + num_ratios = len(ratios) + size_ratio_concat = sizes + ratios + steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height + steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width + offset_h = offsets[0] + offset_w = offsets[1] + + i = bx * max_threads + tx + j = by * max_threads + ty + with ib.if_scope((i < in_height)): + with ib.if_scope((j < in_width)): + center_h = (i + offset_h) * steps_h + center_w = (j + offset_w) * steps_w + + for k in range(num_sizes + num_ratios - 1): + w = tvm.select(k < num_sizes, + size_ratio_concat[k] * in_height / in_width / 2.0, + size_ratio_concat[0] * in_height / in_width * + math.sqrt(size_ratio_concat[k + 1]) / 2.0) + h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0, + size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0) + count = (i * in_width * (num_sizes + num_ratios - 1) + + j * (num_sizes + num_ratios - 1) + k) * 4 + p_out[count] = center_w - w + p_out[count + 1] = center_h - h + p_out[count + 2] = center_w + w + p_out[count + 3] = center_h + h + + body = ib.get() + return body + + +@multibox_prior.register(["cuda", "gpu"]) +def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), \ + offsets=(0.5, 0.5), clip=False): + """Generate prior(anchor) boxes from data, sizes and ratios. + + Parameters + ---------- + data : tvm.Tensor + 4-D with shape [batch, c_in, h_in, w_in]] + + sizes : tuple of float + Tuple of sizes for anchor boxes. + + ratios : tuple of float + Tuple of ratios for anchor boxes. + + steps : Tuple of float + Priorbox step across y and x, -1 for auto calculation. + + offsets : tuple of int + Priorbox center offsets, y and x respectively. + + clip : boolean + Whether to clip out-of-boundary boxes. + + Returns + ------- + out : tvm.Tensor + 3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4] + """ + num_sizes = len(sizes) + num_ratios = len(ratios) + oshape = (1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4) + out = tvm.extern(oshape, [data], lambda ins, outs: + multibox_prior_ir(ins[0], outs[0], sizes, ratios, steps, offsets), + tag="multibox_prior") + if clip: + out = topi.clip(out, 0, 1) + return out + + +def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, threshold, variances): + """Low level IR routing for transform location in multibox_detection operator. + + Parameters + ---------- + cls_prob : Buffer + Buffer of class probabilities. + + loc_pred : Buffer + Buffer of location regression predictions. + + anchor : Buffer + Buffer of prior anchor boxes. + + valid_count : Buffer + Buffer of number of valid output boxes. + + out : Buffer + Output buffer. + + clip : boolean + Whether to clip out-of-boundary boxes. + + threshold : float + Threshold to be a positive prediction. + + variances : tuple of float + Variances to be decoded from box regression output. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh): + """Transform prior anchor box to output box through location predictions. + """ + al = anchor[anchor_base_idx] + at = anchor[anchor_base_idx + 1] + ar = anchor[anchor_base_idx + 2] + ab = anchor[anchor_base_idx + 3] + aw = ar - al + ah = ab - at + ax = (al + ar) / 2.0 + ay = (at + ab) / 2.0 + px = loc[loc_base_idx] + py = loc[loc_base_idx + 1] + pw = loc[loc_base_idx + 2] + ph = loc[loc_base_idx + 3] + ox = px * vx * aw + ax + oy = py * vy * ah + ay + ow = tvm.exp(pw * vw) * aw / 2.0 + oh = tvm.exp(ph * vh) * ah / 2.0 + return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \ + tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \ + tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \ + tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh) + + batch_size = cls_prob.shape[0] + num_classes = cls_prob.shape[1] + num_anchors = cls_prob.shape[2] + + ib = tvm.ir_builder.create() + temp_score = ib.allocate('float32', (batch_size * (num_classes -1) * num_anchors, \ + ), name="temp_score", scope="global") + score = ib.allocate('float32', (batch_size * num_anchors, ), name="score", scope="local") + cls_id = ib.allocate('int32', (batch_size * num_anchors, ), name="id", scope="local") + flag = ib.allocate('int32', (batch_size * num_anchors, ), name="flag", scope="global") + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + nthread_tx = max_threads + nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + p_cls_prob = ib.buffer_ptr(cls_prob) + p_loc_pred = ib.buffer_ptr(loc_pred) + p_anchor = ib.buffer_ptr(anchor) + p_valid_count = ib.buffer_ptr(valid_count) + p_out = ib.buffer_ptr(out) + with ib.if_scope(tid < batch_size * num_anchors * num_classes): + n = tid / (num_anchors * num_classes) + j = (tid % (num_anchors * num_classes)) / num_anchors + i = tid % num_anchors + with ib.if_scope(j > 0): + temp_score[n * num_anchors * num_classes + i * (num_classes - 1) + j-1] = \ + p_cls_prob[tid] + p_valid_count[n] = 0 + with ib.if_scope(tid < batch_size * num_anchors): + n = tid / num_anchors + i = tid % num_anchors + score[tid] = -1.0 + cls_id[tid] = 0 + with ib.for_range(0, num_classes-1, name="k") as k: + temp = temp_score[tid * (num_classes-1) + k] + cls_id[tid] = tvm.select(temp > score[tid], k + 1, cls_id[tid]) + score[tid] = tvm.make.Max(temp, score[tid]) + with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)): + cls_id[tid] = 0 + with ib.if_scope(cls_id[tid] > 0): + flag[tid] = 1 + with ib.else_scope(): + flag[tid] = 0 + with ib.if_scope(tid < batch_size): + with ib.for_range(0, num_anchors, name="k") as k: + with ib.if_scope(k > 0): + flag[tid * num_anchors + k] += flag[tid * num_anchors + k - 1] + p_valid_count[tid] = flag[tid * num_anchors + num_anchors - 1] + with ib.if_scope(tid < batch_size * num_anchors): + n = tid / num_anchors + i = tid % num_anchors + with ib.if_scope(cls_id[tid] > 0): + with ib.if_scope(tid == 0): + out_base_idx = n * num_anchors * 6 + with ib.else_scope(): + out_base_idx = n * num_anchors * 6 + flag[tid - 1] * 6 + p_out[out_base_idx] = cls_id[tid] - 1.0 + p_out[out_base_idx + 1] = score[tid] + p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \ + p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4, p_anchor, i*4, + clip, variances[0], variances[1], + variances[2], variances[3]) + + body = ib.get() + return body + + +@multibox_transform_loc.register(["cuda", "gpu"]) +def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, + variances=(0.1, 0.1, 0.2, 0.2)): + """Location transformation for multibox detection + + Parameters + ---------- + cls_prob : tvm.Tensor + Class probabilities. + + loc_pred : tvm.Tensor + Location regression predictions. + + anchor : tvm.Tensor + Prior anchor boxes. + + clip : boolean + Whether to clip out-of-boundary boxes. + + threshold : float + Threshold to be a positive prediction. + + variances : tuple of float + Variances to be decoded from box regression output. + + Returns + ------- + ret : tuple of tvm.Tensor composed of + + out : tvm.Tensor + 3-D tensor with shape (batch_size, num_anchors, 6) + + valid_count : tvm.Tensor + 1-D tensor with shape (batch_size,), number of valid anchor boxes. + """ + batch_size = cls_prob.shape[0] + num_anchors = anchor.shape[1] + oshape = (batch_size, num_anchors, 6) + # Define data alignment for intermediate buffer + valid_count_dtype = "int32" + valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype, + "valid_count_buf", data_alignment=4) + out_buf = api.decl_buffer(oshape, cls_prob.dtype, "out_buf", data_alignment=8) + valid_count, out = \ + tvm.extern([(batch_size,), oshape], + [cls_prob, loc_pred, anchor], + lambda ins, outs: transform_loc_ir( + ins[0], ins[1], ins[2], outs[0], outs[1], clip, threshold, variances), + dtype=[valid_count_dtype, cls_prob.dtype], + out_buffers=[valid_count_buf, out_buf], + tag="multibox_transform_loc") + return [out, valid_count] + + +@multibox_detection.register(["cuda", "gpu"]) +def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5, + force_suppress=False, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=-1): + """Convert multibox detection predictions. + + Parameters + ---------- + cls_prob : tvm.Tensor + Class probabilities. + + loc_pred : tvm.Tensor + Location regression predictions. + + anchor : tvm.Tensor + Prior anchor boxes. + + clip : boolean + Whether to clip out-of-boundary boxes. + + nms_threshold : float + Non-maximum suppression threshold. + + force_suppress : boolean + Whether to suppress all detections regardless of class_id. + + threshold : float + Threshold to be a positive prediction. + + variances : tuple of float + Variances to be decoded from box regression output. + + nms_topk : int + Keep maximum top k detections before nms, -1 for no limit. + + Returns + ------- + out : tvm.Tensor + 3-D tensor with shape (batch_size, num_anchors, 6) + """ + inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, + clip, threshold, variances) + out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) + return out diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py index 106d13665ad8..c5d94b5ab4de 100644 --- a/topi/python/topi/cuda/vision.py +++ b/topi/python/topi/cuda/vision.py @@ -1,9 +1,42 @@ -# pylint: disable=invalid-name, unused-variable, unused-argument +# pylint: disable=invalid-name, unused-variable, unused-argument, no-member """Schedule for vision operators""" from __future__ import absolute_import as _abs import tvm from .. import generic from .. import cpp +from .. import tag + +def _default_schedule(outs): + """Default schedule for gpu.""" + target = tvm.target.current_target() + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + def traverse(op): + """inline all one-to-one-mapping operators except the last stage (output)""" + if "nms" in op.tag: + sort = op.input_tensors[1] + score = s[sort].op.input_tensors[0] + fused = s[score].fuse(*s[score].op.axis) + num_thread = tvm.target.current_target(allow_none=False).max_num_threads + bx, tx = s[score].split(fused, factor=num_thread) + s[score].bind(bx, tvm.thread_axis("blockIdx.x")) + s[score].bind(tx, tvm.thread_axis("threadIdx.x")) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + else: + x = op.output(0) + fused = s[x].fuse(*s[x].op.axis) + num_thread = tvm.target.current_target(allow_none=False).max_num_threads + bx, tx = s[x].split(fused, factor=num_thread) + s[x].bind(bx, tvm.thread_axis("blockIdx.x")) + s[x].bind(tx, tvm.thread_axis("threadIdx.x")) + for tensor in op.input_tensors: + if tensor.op.input_tensors: + traverse(tensor.op) + + traverse(outs[0].op) + return s @generic.schedule_reorg.register(["cuda", "gpu"]) def schedule_reorg(outs): @@ -41,8 +74,25 @@ def schedule_region(outs): cpp_target = cpp.TEST_create_target(target.target_name) return cpp.cuda.schedule_region(cpp_target, outs) +@generic.schedule_nms.register(["cuda", "gpu"]) +def schedule_nms(outs): + """Schedule for non-maximum suppression + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of nms + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs) + @generic.schedule_multibox_prior.register(["cuda", "gpu"]) -def schedule_multibox_prior(out): +def schedule_multibox_prior(outs): """Schedule for multibox_prior operator. Parameters @@ -56,10 +106,28 @@ def schedule_multibox_prior(out): s: Schedule The computation schedule for multibox_prior. """ - raise RuntimeError("Currently multibox_prior only supports CPU.") + return _default_schedule(outs) + +@generic.schedule_multibox_transform_loc.register(["cuda", "gpu"]) +def schedule_multibox_transform_loc(outs): + """Schedule for multibox_transform_loc + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of + multibox_transform_loc in the format + of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs) @generic.schedule_multibox_detection.register(["cuda", "gpu"]) -def schedule_multibox_detection(out): +def schedule_multibox_detection(outs): """Schedule for multibox_detection operator. Parameters @@ -73,4 +141,4 @@ def schedule_multibox_detection(out): s: Schedule The computation schedule for multibox_detection. """ - raise RuntimeError("Currently multibox_detection only supports CPU.") + return _default_schedule(outs) diff --git a/topi/tests/python/test_topi_clip.py b/topi/tests/python/test_topi_clip.py index 52da4922e1d6..041565433bcc 100644 --- a/topi/tests/python/test_topi_clip.py +++ b/topi/tests/python/test_topi_clip.py @@ -20,23 +20,27 @@ def get_ref_data(): a_np, b_np = get_ref_data() def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return - ctx = tvm.cpu(0) if device == "llvm" else tvm.gpu(0) + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(B) + a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) f = tvm.build(s, [A, B], device, name="clip") f(a, b) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - for device in ['llvm']: + for device in ['llvm', 'opencl']: check_device(device) def test_clip(): - verify_clip(1024, -127, 127, 'int8') - verify_clip(1024, -127, 127, 'int16') verify_clip(1024, -127, 127, 'float32') + verify_clip(1024, -127, 127, 'int16') + verify_clip(1024, -127, 127, 'int8') if __name__ == "__main__": diff --git a/topi/tests/python/test_topi_conv2d_nchw.py b/topi/tests/python/test_topi_conv2d_nchw.py index 0d55a9163466..7f3935f3aad7 100644 --- a/topi/tests/python/test_topi_conv2d_nchw.py +++ b/topi/tests/python/test_topi_conv2d_nchw.py @@ -8,6 +8,8 @@ from topi.util import get_const_tuple def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): + print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) + in_height = in_width = in_size A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') @@ -59,7 +61,7 @@ def check_device(device): def test_conv2d_nchw(): - # ResNet18 worklaods + # ResNet18 workloads verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) @@ -72,6 +74,21 @@ def test_conv2d_nchw(): verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) + # ResNet 50 workloads + verify_conv2d_nchw(1, 64, 56, 256, 1, 1, 0) + verify_conv2d_nchw(1, 256, 56, 64, 1, 1, 0) + verify_conv2d_nchw(1, 256, 56, 128, 1, 2, 0) + verify_conv2d_nchw(1, 128, 28, 512, 1, 1, 0) + verify_conv2d_nchw(1, 256, 56, 512, 1, 2, 0) + verify_conv2d_nchw(1, 512, 28, 128, 1, 1, 0) + verify_conv2d_nchw(1, 512, 28, 256, 1, 2, 0) + verify_conv2d_nchw(1, 256, 14, 1024, 1, 1, 0) + verify_conv2d_nchw(1, 512, 28, 1024, 1, 2, 0) + verify_conv2d_nchw(1, 1024, 14, 256, 1, 1, 0) + verify_conv2d_nchw(1, 1024, 14, 512, 1, 2, 0) + verify_conv2d_nchw(1, 512, 7, 2048, 1, 2, 0) + verify_conv2d_nchw(1, 1024, 14, 2048, 1, 2, 0) + verify_conv2d_nchw(1, 2048, 7, 512, 1, 1, 0) # Vgg16 workloads verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1) # Super resolution workloads diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 3c624726d562..959b10f82ca5 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -14,7 +14,6 @@ def test_nms(): nms_threshold = 0.7 force_suppress = True nms_topk = 2 - out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], @@ -31,6 +30,10 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): + if device == 'llvm': + out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) + else: + out = topi.cuda.nms(data, valid_count, nms_threshold, force_suppress, nms_topk) s = topi.generic.schedule_nms(out) tvm_data = tvm.nd.array(np_data, ctx) @@ -40,13 +43,12 @@ def check_device(device): f(tvm_data, tvm_valid_count, tvm_out) np.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4) - for device in ['llvm']: + for device in ['llvm', 'opencl']: check_device(device) def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False): data = tvm.placeholder(dshape, name="data") - out = ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip) dtype = data.dtype input_data = np.random.uniform(size=dshape).astype(dtype) @@ -88,15 +90,19 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): + if device == 'llvm': + out = ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip) + else: + out = topi.cuda.ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip) s = topi.generic.schedule_multibox_prior(out) tvm_input_data = tvm.nd.array(input_data, ctx) tvm_out = tvm.nd.array(np.zeros(oshape, dtype=dtype), ctx) f = tvm.build(s, [data, out], device) f(tvm_input_data, tvm_out) - np.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-4) + np.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-3) - for device in ['llvm']: + for device in ['llvm', 'opencl']: check_device(device) @@ -113,7 +119,6 @@ def test_multibox_detection(): cls_prob = tvm.placeholder((batch_size, num_anchors, num_classes), name="cls_prob") loc_preds = tvm.placeholder((batch_size, num_anchors * 4), name="loc_preds") anchors = tvm.placeholder((1, num_anchors, 4), name="anchors") - out = ssd.multibox_detection(cls_prob, loc_preds, anchors) # Manually create test case np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]]) @@ -131,6 +136,10 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): + if device == 'llvm': + out = ssd.multibox_detection(cls_prob, loc_preds, anchors) + else: + out = topi.cuda.ssd.multibox_detection(cls_prob, loc_preds, anchors) s = topi.generic.schedule_multibox_detection(out) tvm_cls_prob = tvm.nd.array(np_cls_prob.astype(cls_prob.dtype), ctx) @@ -141,11 +150,11 @@ def check_device(device): f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out) np.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, rtol=1e-4) - for device in ['llvm']: + for device in ['llvm', 'opencl']: check_device(device) if __name__ == "__main__": test_nms() test_multibox_prior() - test_multibox_detection() \ No newline at end of file + test_multibox_detection()