
Parallelize cumsum in get_valid_counts #7123

Merged
merged 6 commits into apache:main on Dec 31, 2020

Conversation

mbrookhart
Contributor

@mbrookhart commented Dec 17, 2020

As a follow-up to #6839, this parallelizes the cumsum in get_valid_counts using an upsweep/downsweep tree-based prefix sum algorithm, similar to what I did in #7099.
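
For illustration, here is a minimal NumPy sketch of the upsweep/downsweep (Blelloch) exclusive scan pattern the kernel follows; this is just the algorithm on a power-of-two-length array, not the TVM IR in this PR:

import numpy as np

def exclusive_scan(data):
    # Blelloch-style exclusive prefix sum; assumes len(data) is a power of two.
    out = np.array(data, dtype=np.int64)
    n = len(out)
    # Upsweep: build partial sums up the tree.
    stride = 1
    while stride < n:
        out[2 * stride - 1::2 * stride] += out[stride - 1::2 * stride]
        stride *= 2
    # Downsweep: clear the root, then push prefixes back down the tree.
    out[-1] = 0
    stride = n // 2
    while stride >= 1:
        right = out[2 * stride - 1::2 * stride].copy()
        out[2 * stride - 1::2 * stride] += out[stride - 1::2 * stride]
        out[stride - 1::2 * stride] = right
        stride //= 2
    return out

print(exclusive_scan([1, 0, 1, 1, 0, 1, 0, 0]))  # [0 1 1 2 3 3 4 4]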

On my 1070 Ti, testing deploy_ssd_gluoncv.py, I previously reported that get_valid_counts took 3674.62 microseconds; this reduces that to 258.497 microseconds (an earlier revision of this PR measured 495.8).

@masahi has expressed interest in implementing a more general prefix scan for other ops; as future work, I expect we'll refactor this and look at possible cache optimizations.

Thanks

cc @Laurawly @zhiics @kevinthesun

@masahi
Member

masahi commented Dec 18, 2020

@mbrookhart Can you revive the disabled topi get_valid_count test? It seems this test needs some updating.

@tvm.testing.uses_gpu
@pytest.mark.skip(
"Skip this test as it is intermittent."
"See https://github.com/apache/tvm/pull/4901#issuecomment-595040094"
)
def test_get_valid_counts():

@mbrookhart
Contributor Author

ping @Laurawly, any chance you could take a look?

@masahi
Member

masahi commented Dec 31, 2020

@Laurawly The plan is that after we merge this first, we will generalize the cumsum IR in this PR into a reusable exclusive scan primitive. After that, we can update our CUDA argwhere implementation to use exclusive scan + compaction, and introduce a numpy-style cumsum operator.
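
As a rough NumPy illustration of the exclusive scan + compaction pattern (argwhere_1d below is a hypothetical helper name, not a TVM API):

import numpy as np

def argwhere_1d(x):
    # Sketch: argwhere via an exclusive scan over a predicate mask, then compaction.
    mask = (x != 0).astype(np.int64)            # 1 where the predicate holds
    positions = np.cumsum(mask) - mask          # exclusive scan gives output slots
    out = np.empty(int(mask.sum()), dtype=np.int64)
    for i in range(len(x)):                     # compaction: scatter surviving indices
        if mask[i]:
            out[positions[i]] = i
    return out

print(argwhere_1d(np.array([0, 3, 0, 5, 7, 0])))  # [1 3 4]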

@Laurawly
Contributor

@Laurawly The plan is that after we merge this first, we will generalize the cumsum IR in this PR into a reusable exclusive scan primitive. After that, we can update our CUDA argwhere implementation to use exclusive scan + compaction, and introduce a numpy-style cumsum operator.

Sure, I can merge this first.

@Laurawly merged commit c02c9c5 into apache:main on Dec 31, 2020
@trevor-m
Contributor

trevor-m commented Jan 5, 2021

Hi @mbrookhart, thanks for this performance improvement!

I found that this PR is causing a "CUDA: an illegal memory access was encountered" error during inference for a TensorFlow SSD object detection model. I can't reproduce it in a standalone unit test, so I think there may be some race condition or code relying on uninitialized memory. I'll let you know if I find out anything more.

@mbrookhart
Contributor Author

Thanks, Trevor. If you can share the model script you're using, I can also work to debug today.

@mbrookhart deleted the get_valid_counts_prefix_sum branch January 5, 2021 17:12
@masahi
Member

masahi commented Jan 5, 2021

I can reproduce the issue by running the ssd test in tensorflow/test_forward.py with the cuda target (I looked at this test yesterday for my PR, so it's fresh in my memory):

terminate called after throwing an instance of 'dmlc::Error'
  what():  [05:42:13] /home/masa/projects/dev/tvm/src/runtime/cuda/cuda_device_api.cc:126: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
  Check failed: e == cudaSuccess || e == cudaErrorCudartUnloading == false: CUDA: an illegal memory access was encountered
Stack trace:
  [bt] (0) /home/masa/projects/dev/tvm/build/libtvm.so(+0x14aa8e8) [0x7f4fcb8ca8e8]
  [bt] (1) /home/masa/projects/dev/tvm/build/libtvm.so(tvm::runtime::CUDADeviceAPI::FreeDataSpace(DLContext, void*)+0xe4) [0x7f4fcb8cabe4]
  [bt] (2) /home/masa/projects/dev/tvm/build/libtvm.so(tvm::runtime::NDArray::Internal::DefaultDeleter(tvm::runtime::Object*)+0x5b) [0x7f4fcb8593fb]
  [bt] (3) /home/masa/projects/dev/tvm/build/libtvm.so(tvm::runtime::NDArray::CopyTo(DLContext const&) const+0x325) [0x7f4fcb5e4915]
  [bt] (4) /home/masa/projects/dev/tvm/build/libtvm.so(tvm::runtime::vm::CopyTo(tvm::runtime::ObjectRef, DLContext const&)+0x311) [0x7f4fcb884b11]
  [bt] (5) /home/masa/projects/dev/tvm/build/libtvm.so(tvm::runtime::vm::VirtualMachine::RunLoop()+0x2aee) [0x7f4fcb880dde]
  [bt] (6) /home/masa/projects/dev/tvm/build/libtvm.so(tvm::runtime::vm::VirtualMachine::Invoke(tvm::runtime::vm::VMFunction const&, std::vector<tvm::runtime::ObjectRef, std::allocator<tvm::runtime::ObjectRef> > const&)+0x27) [0x7f4fcb881c17]
  [bt] (7) /home/masa/projects/dev/tvm/build/libtvm.so(+0x14621f0) [0x7f4fcb8821f0]
  [bt] (8) /home/masa/projects/dev/tvm/build/libtvm.so(TVMFuncCall+0x63) [0x7f4fcb835613]

@trevor-m Are you sure this is caused by the get_valid_counts change? I've also changed NMS in #7172; I hope that change is fine.

@masahi
Member

masahi commented Jan 5, 2021

Hmm, strange: after running the ssd test on GPU a few times, I cannot reproduce the error anymore. Could this error be random?

One annoying thing about this model is that compilation is extremely slow. It also requires increasing the stack size limit, otherwise it segfaults.

@mbrookhart
Contributor Author

Yeah, ouch:
447.33s (0:07:27)

I don't need to increase the stack limit, and I haven't gotten this test to fail yet.

@trevor-m
Contributor

trevor-m commented Jan 5, 2021

Hmm, strange: after running the ssd test on GPU a few times, I cannot reproduce the error anymore. Could this error be random?

One annoying thing about this model is that compilation is extremely slow. It also requires increasing the stack size limit, otherwise it segfaults.

Yeah, the error is a bit random. However, I was able to reproduce it 100% of the time with TRT offload enabled. I can share a script shortly.

@trevor-m Are you sure this is caused by the get_valid_counts change? I've also changed NMS in #7172; I hope that change is fine.

Yeah, I did a git bisect to determine this PR was the source of the issue, and #7172 was fine.

@anijain2305
Contributor

Hmm, strange: after running the ssd test on GPU a few times, I cannot reproduce the error anymore. Could this error be random?

One annoying thing about this model is that compilation is extremely slow. It also requires increasing the stack size limit, otherwise it segfaults.

Maybe it depends on the input data. Trevor and I ran it across a bunch of models, and it fails for a few of them (not all). I believe it could be because of the input data (the number of boxes etc. changes with the input image).

@mbrookhart
Contributor Author

@trevor-m I'm in mountain time, so I'll need to leave in about half an hour. If you can post the script that consistently fails tonight, I'll jump in first thing tomorrow morning and start hunting for which line causes the issue.

@masahi
Member

masahi commented Jan 5, 2021

@anijain2305 @trevor-m We should definitely use a fixed, real image for CI testing, like the pytorch MaskRCNN test does. Please send a PR.

img = "test_street_small.jpg"
img_url = (
"https://raw.githubusercontent.com/dmlc/web-data/"
"master/gluoncv/detection/street_small.jpg"
)
download(img_url, img)

@trevor-m
Contributor

trevor-m commented Jan 7, 2021

I ran the model that was failing (ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03) under cuda-gdb and was able to get some information from the crash:

CUDA Exception: Warp Out-of-range Address
The exception was triggered at PC 0x55556b2ef890

Thread 1 "python3" received signal CUDA_EXCEPTION_5, Warp Out-of-range Address.
[Switching focus to CUDA kernel 2, grid 894452, block (0,0,0), thread (864,0,0), device 0, sm 0, warp 24, lane 0]
0x000055556b2ef8b0 in fused_vision_non_max_suppression_kernel1<<<(1,1,1),(1024,1,1)>>> ()

@masahi Any thoughts?

@masahi
Member

masahi commented Jan 7, 2021

Does this mean the NMS kernel, and not get_valid_counts, has an issue? I recognize the thread launch config (1,1,1),(1024,1,1); this is due to my NMS change. But that kernel should be fused_vision_non_max_suppression_kernel2, not fused_vision_non_max_suppression_kernel1 as shown above, so this is weird.

@mbrookhart
Contributor Author

mbrookhart commented Jan 7, 2021

Looking at the code, assuming you have thrust enabled, this should be kernel0:

score_tensor = te.extern(
    [score_shape],
    [data],
    lambda ins, outs: _fetch_score_ir(
        ins[0],
        outs[0],
        score_axis,
    ),
    dtype=[data.dtype],
    in_buffers=[data_buf],
    out_buffers=[score_buf],
    name="fetch_score",
    tag="fetch_score",
)

The thrust argsort won't get a kernel number:
sort_tensor = argsort_thrust(
    score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
)

And this should be kernel1:
with ib.new_scope():
    nthread_tx = max_threads
    nthread_bx = ceil_div(num_anchors, max_threads)
    nthread_by = batch_size
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    by = te.thread_axis("blockIdx.y")
    ib.scope_attr(by, "thread_extent", nthread_by)
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    i = by
    base_idx = i * num_anchors * box_data_length
    with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
        # Reorder output
        nkeep = if_then_else(
            tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
        )
        j = bx * max_threads + tx
        with ib.if_scope(j < num_anchors):
            box_indices[i * num_anchors + j] = -1
        with ib.if_scope(j < nkeep):
            # Fill in out with sorted boxes
            with ib.for_range(0, box_data_length) as k:
                out[(base_idx + j * box_data_length + k)] = data[
                    (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
                ]
        with ib.else_scope():
            # Indices > nkeep are discarded
            with ib.if_scope(j < num_anchors):
                with ib.for_range(0, box_data_length) as k:
                    out[(base_idx + j * box_data_length + k)] = -1.0
    with ib.else_scope():
        with ib.if_scope(j < valid_count[i]):
            with ib.for_range(0, box_data_length) as k:
                offset = base_idx + j * box_data_length + k
                out[offset] = data[offset]
            box_indices[i * num_anchors + j] = j

That could have threads (1,1,1),(1024,1,1) if we have batch_size=1 and num_anchors <= 1024. I'm not seeing anything in there that jumps out as having an issue, though. Every use of j is guarded by an if scope with j < num_anchors, j < nkeep, or j < valid_count, and nkeep <= valid_count. The only way it could fail is if valid_count > num_anchors...

So possibly it's failing because my changes to get_valid_count are returning the wrong valid_count.

@trevor-m Any chance we can dump the inputs/attrs for get_valid_count so I can make a unit test to check that hypothesis? I haven't been able to get it to fail with random inputs, but possibly there's an edge case in my exclusive_scan algorithm for this input data.
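
For reference, a small NumPy model of the invariant being checked (ref_valid_counts is a hypothetical helper, not topi code): with id_index=-1, valid_count just counts boxes whose score exceeds the threshold, so it can never exceed num_anchors.

import numpy as np

def ref_valid_counts(boxes, score_threshold=0.0, score_index=0):
    # boxes: (batch, num_anchors, box_data_length); the score sits at score_index.
    scores = boxes[:, :, score_index]
    return (scores > score_threshold).sum(axis=1).astype("int32")

boxes = np.random.uniform(-1, 1, (1, 100, 5)).astype("float32")
assert (ref_valid_counts(boxes) <= boxes.shape[1]).all()  # valid_count <= num_anchors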

@trevor-m
Contributor

trevor-m commented Jan 7, 2021

Thanks for looking into it and finding that info, @mbrookhart!

Here is the relevant relay graph:

boxes = relay.var("boxes", shape=(1, relay.Any(), 5), dtype="float32")

max_output_size = relay.shape_of(boxes)
max_output_size = relay.strided_slice(max_output_size, begin=[1], end=[2], strides=[1])
max_output_size = relay.squeeze(max_output_size)
max_output_size = relay.minimum(relay.const(100, dtype="int32"), max_output_size)

ct, data, indices = relay.vision.get_valid_counts(
    boxes, score_threshold=0.0, id_index=-1, score_index=0
)

nms_ret = relay.vision.non_max_suppression(
    data=boxes,
    valid_count=ct,
    indices=indices,
    max_output_size=max_output_size,
    iou_threshold=0.6,
    force_suppress=True,
    top_k=-1,
    coord_start=1,
    score_index=0,
    id_index=-1,
    return_indices=True,
    invalid_to_bottom=False,
)

The input shape is [1, 0, 5] during the model execution when the crash occurs. I haven't been able to reproduce it with this standalone test yet. Maybe there is an edge case for a size-0 max_output_size or num_anchors?
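
For reference, one way the graph above could be driven with the crashing [1, 0, 5] shape (a hypothetical harness sketch; the executor plumbing is an assumption and may need adjusting for a given TVM version):

# Hypothetical harness (sketch): reuses `boxes` and `nms_ret` from the snippet above
# and feeds the empty [1, 0, 5] input observed at the crash.
import numpy as np
import tvm
from tvm import relay

func = relay.Function([boxes], nms_ret.astuple())
mod = tvm.IRModule.from_expr(func)

empty_boxes = np.zeros((1, 0, 5), dtype="float32")  # zero anchors
run = relay.create_executor("vm", mod=mod, target="cuda").evaluate()
print(run(empty_boxes))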

@mbrookhart
Contributor Author

Ooh, interesting: doing NMS on no boxes. I'll take a look with that idea.

@mbrookhart
Contributor Author

mbrookhart commented Jan 7, 2021

I don't think this is valid if num_anchors is zero; it could lead to undefined behavior. Could you wrap it in an ib.if_scope(num_anchors > 0) and see if that fixes the problem?

with ib.new_scope():
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(bx, "thread_extent", batch_size)
    with ib.if_scope(bx < batch_size):
        valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
        valid_indices[(bx + 1) * num_anchors - 1] = 0

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jan 7, 2021
@anijain2305
Contributor

anijain2305 commented Jan 7, 2021

I don't think this is valid if num_anchors is zero; it could lead to undefined behavior. Could you wrap it in an ib.if_scope(num_anchors > 0) and see if that fixes the problem?

with ib.new_scope():
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(bx, "thread_extent", batch_size)
    with ib.if_scope(bx < batch_size):
        valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
        valid_indices[(bx + 1) * num_anchors - 1] = 0

@@ -210,8 +211,9 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
         bx = te.thread_axis("blockIdx.x")
         ib.scope_attr(bx, "thread_extent", batch_size)
         with ib.if_scope(bx < batch_size):
-            valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
-            valid_indices[(bx + 1) * num_anchors - 1] = 0
+            with ib.if_scope(num_anchors > 0):
+                valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
+                valid_indices[(bx + 1) * num_anchors - 1] = 0

     with ib.for_range(0, lim, dtype="int64") as l2_width:
         width = 2 << (lim - l2_width - 1)

I tried this yesterday. Unfortunately, this is not the source. The test still failed.

@mbrookhart
Contributor Author

Alas. I would still very much appreciate the script to reproduce this so I can hunt it down.

@mbrookhart
Contributor Author

lim = tvm.tir.generic.cast(
    tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
)

log2(0) is undefined; we should probably just wrap the entire thing in an if num_anchors > 0.
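
One possible alternative to an if scope (a sketch of the idea only, not the fix that eventually landed) is to clamp the argument so log2 never sees zero; tvm.tir.if_then_else is assumed to be usable here:

# Sketch: clamp num_anchors to at least 1 before taking log2, so lim is 0 when
# there are no anchors and the scan loop runs zero iterations.
safe_anchors = tvm.tir.if_then_else(num_anchors > 0, num_anchors, 1)
lim = tvm.tir.generic.cast(
    tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(safe_anchors, "float64"))), "int64"
)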

@anijain2305
Contributor

Alas. I would still very much appreciate the script to reproduce this so I can hunt it down.

Yes, Trevor is working on it. It needs the TRT workflow, and that's why the delay.

lim = tvm.tir.generic.cast(
    tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
)

log2(0) is undefined; we should probably just wrap the entire thing in an if num_anchors > 0.

Let me try this as well.

@trevor-m
Contributor

trevor-m commented Jan 7, 2021

Thanks @mbrookhart, I tried that but it didn't fix the error.

Here is a script to reproduce the error: https://gist.github.com/trevor-m/f44d3d0e7edcaee12722e518e5959b82

I also noticed this line in the kernel where cuda-gdb found a crash: https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/nms.py#L545
Shouldn't nthread_bx be 0 if num_anchors is 0? Does this mean num_anchors is wrong?

@mbrookhart
Contributor Author

Thanks, I'll try to reproduce. You're building with thrust and TRT, right?

You can't compile a cuda kernel with zero threads, so we always make sure it's at least 1:

if attr_key == "thread_extent":
value = op.max(1, value)
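
Concretely, here's the arithmetic (ceil_div reimplemented below just for illustration):

def ceil_div(a, b):
    # Integer ceiling division, as used to compute the block count.
    return (a + b - 1) // b

max_threads = 1024
num_anchors = 0
nthread_bx = ceil_div(num_anchors, max_threads)  # 0 blocks requested
print(nthread_bx, max(1, nthread_bx))            # 0 1
# The thread_extent clamp launches one block anyway, so body guards such as
# `j < num_anchors` are what keep the kernel from touching out-of-range memory.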

@trevor-m
Contributor

trevor-m commented Jan 7, 2021

Thanks, I'll try to reproduce. You're building with thrust and TRT, right?

You can't compile a cuda kernel with zero threads, so we always make sure it's at least 1:

if attr_key == "thread_extent":
    value = op.max(1, value)

Yes, that's right, thrust + TRT. Thank you for your help with debugging this.

@mbrookhart
Contributor Author

For posterity, @trevor-m and I did some offline debugging yesterday, and #7229 seems to fix the issue.

tkonolige pushed a commit to tkonolige/incubator-tvm that referenced this pull request Jan 11, 2021
* Parallelize cumsum in get_valid_counts

* make the scan loop exclusive

* switch to directly using exclusive scan

* perform inner loop of final writes on anchor threads

* fix flaky test

fix lint

* remove final cuda kernel

Co-authored-by: masa <masa@pop-os.localdomain>
@mbrookhart restored the get_valid_counts_prefix_sum branch January 19, 2021 16:44
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Jan 20, 2021
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jan 21, 2021
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021