
Handle duplicate indices in Numba implementation of AdvancedIncSubtensor1 #1081

Closed
wants to merge 5 commits

Conversation

@aseyboldt (Contributor) commented Jul 28, 2022

Here are a few important guidelines and requirements to check before your PR can be merged:

  • There is an informative high-level description of the changes.
  • The description and/or commit message(s) references the relevant GitHub issue(s).
  • pre-commit is installed and set up.
  • The commit messages follow these guidelines.
  • The commits correspond to relevant logical changes, and there are no commits that fix changes introduced by other commits in the same branch/PR.
  • There are tests covering the changes introduced in the PR.

First, this PR removes the incorrect numba implementation of AdvancedIncSubtensor, so it will now just fall back to object mode and be slow but correct. (We should also provide a numba implementation for this, but that will be a separate PR.)

But we do add a new implementation for AdvancedIncSubtensor1 (the much more common case). Here, we also take advantage of the fact that sometimes we know the indices beforehand, so we can simplify bounds checks, and generate cleaner and faster assembly code.

Also, a one-line fix for cumop accidentally made its way into the PR, but it is simple enough that maybe we can just keep it here? (It has its own commit.)
The problem was that the numba.prange loop had data races.
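As an aside, the cumop race can be caricatured without numba at all. The NumPy-only sketch below (the chunking is illustrative, not the actual prange schedule) shows what effectively happens when a cumulative reduction with a loop-carried dependency is split across parallel workers that each start their accumulator at zero:

```python
import numpy as np

x = np.arange(1.0, 9.0)  # [1, 2, ..., 8]

# Racy caricature: each "thread" chunk computes its own cumsum starting
# from 0, instead of carrying the running total of the preceding chunks.
chunks = np.split(x, 2)
racy = np.concatenate([np.cumsum(c) for c in chunks])

# Correct sequential cumulative sum.
correct = np.cumsum(x)
```

Here `racy` comes out as `[1, 3, 6, 10, 5, 11, 18, 26]` while `correct` is `[1, 3, 6, 10, 15, 21, 28, 36]` — the second half is missing the first chunk's total, which is why the loop must run sequentially.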

fixes #603
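For reference, the duplicate-index semantics at stake can be shown in plain NumPy (an illustration of the bug class, not the PR's numba code): a vectorized read-modify-write drops repeated indices, while np.add.at — the behavior inc_subtensor must match — accumulates them:

```python
import numpy as np

x = np.zeros(3)
idxs = np.array([0, 0, 1])
vals = np.array([1.0, 2.0, 3.0])

# Naive fancy-index increment: the read-modify-write is vectorized, so
# for the duplicated index 0 the last write wins (2.0 instead of 3.0).
naive = x.copy()
naive[idxs] += vals  # -> [2., 3., 0.]

# np.add.at applies each increment sequentially, accumulating duplicates.
correct = x.copy()
np.add.at(correct, idxs, vals)  # -> [3., 3., 0.]
```

The old numba implementation reproduced the naive behavior, which is exactly the incorrectness described in #603.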

codecov bot commented Jul 28, 2022

Codecov Report

Merging #1081 (1e6be72) into main (f2a7fb9) will decrease coverage by 0.04%.
The diff coverage is 85.48%.

❗ Current head 1e6be72 differs from pull request most recent head f3f65ad. Consider uploading reports for the commit f3f65ad to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1081      +/-   ##
==========================================
- Coverage   79.28%   79.24%   -0.05%     
==========================================
  Files         159      152       -7     
  Lines       48111    48002     -109     
  Branches    10937    10922      -15     
==========================================
- Hits        38145    38038     -107     
- Misses       7454     7458       +4     
+ Partials     2512     2506       -6     
Impacted Files Coverage Δ
aesara/link/numba/dispatch/basic.py 90.73% <85.24%> (-1.34%) ⬇️
aesara/link/numba/dispatch/extra_ops.py 98.00% <100.00%> (ø)
aesara/tensor/basic_opt.py 85.90% <0.00%> (-14.10%) ⬇️
aesara/tensor/subtensor_opt.py 87.12% <0.00%> (-12.88%) ⬇️
aesara/tensor/math_opt.py 87.27% <0.00%> (-12.73%) ⬇️
aesara/tensor/opt_uncanonicalize.py 96.03% <0.00%> (-3.97%) ⬇️
aesara/ifelse.py 49.71% <0.00%> (-1.30%) ⬇️
aesara/link/numba/dispatch/random.py 97.74% <0.00%> (-1.09%) ⬇️
aesara/sandbox/multinomial.py 75.91% <0.00%> (-0.61%) ⬇️
... and 58 more

@brandonwillard brandonwillard added bug Something isn't working important Numba Involves Numba transpilation labels Jul 28, 2022
@aseyboldt aseyboldt changed the title Fix numba AdvancedIncSubtensor Fix numba AdvancedIncSubtensor1 Jul 28, 2022
Comment on lines 538 to 545
@numba_njit
def check(z, vals):
    if max_idx >= len(z):
        raise IndexError(msg_pos)
    if max(-min_idx, 0) > len(z):
        raise IndexError(msg_neg)
    if len(idxs) != len(vals):
        raise ValueError("Incompatible dimensions during indexing")
Member
We generally try to avoid manual checks at the Numba level. Partly because the same checks can be introduced downstream by Numba itself. For instance, if the values are constants, then some such checks occur at compile-time (e.g. during constant folding-like optimization passes), and at run-time in the code generated by zip, for, and z[idx].

Also, we want to honor the Numba-level options for disabling things like bounds checks. If we add similar checks of our own, then we can end up effectively overriding such options.

There are definitely cases in which manual checks are warranted, but it's not clear to me whether or not this is one of those cases.

Contributor Author

The point is that if we do the checks manually, we can always set boundscheck=False for that function without any safety issues (when the indices are known at compile time, as indicated in the code).
I checked that this cleans up the generated assembly quite a bit, and also speeds up the execution of the function.
I've seen a lot of models where a big share of the execution time is in this op, so even smaller performance improvements matter.

I think manual checks might even be a good idea in many cases in general: often the optimizer will then optimize away the later checks, and we end up with cleaner and faster asm. Generally, the fewer checks that happen within the loop, the better.
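A pure-Python sketch of this hoisting idea (function name and messages are hypothetical, not the PR's actual code): every check runs once before the hot loop, so the loop body itself needs no per-element bounds test and could safely run unchecked:

```python
import numpy as np

def inc_subtensor1_checked(x, vals, idxs):
    # Hoisted checks: validate dimensions and the full index range once,
    # so each x[idxs[i]] access below is already known to be in bounds
    # (the equivalent of running the loop with boundscheck=False).
    n = x.shape[0]
    if len(idxs) != len(vals):
        raise ValueError("Incompatible dimensions during indexing")
    if len(idxs) and (idxs.max() >= n or idxs.min() < -n):
        raise IndexError("index out of bounds")
    for i in range(len(idxs)):
        x[idxs[i]] += vals[i]  # bounds already established above
    return x
```

With duplicate and negative indices this accumulates correctly, e.g. `inc_subtensor1_checked(np.zeros(3), np.array([1.0, 2.0, 3.0]), np.array([0, 0, -1]))` yields `[3., 0., 3.]`.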

@aseyboldt aseyboldt marked this pull request as ready for review July 28, 2022 23:10
@aseyboldt (Contributor Author)

Is it possible that codecov doesn't pick up changes when I force-push? Somehow I still see the old version of the patch in the coverage report...

@brandonwillard brandonwillard (Member) left a comment

Can you demonstrate how these manually added checks and constant case specializations are advantageous and/or not already covered by Numba's optimizations? I'm very hesitant to incorporate this additional complexity without some kind of demonstrable benefit.

@brandonwillard brandonwillard changed the title Fix numba AdvancedIncSubtensor1 Handle duplicate indices in Numba implementation of AdvancedIncSubtensor1 Aug 4, 2022
@aseyboldt (Contributor Author)

Comparison of the constant index case with the non-constant index case:

%env NUMBA_BOUNDSCHECK=0

import aesara
import aesara.tensor as at
import numpy as np

n, k = 100_000, 100
idxs_vals = np.random.randint(k, size=n)
#idxs_vals.sort()
x_vals = np.random.randn(k)
a_vals = np.random.randn(n)

x = at.dvector("x")
a = at.dvector("d")
idxs = at.vector("idx", dtype=np.int64)
out = at.inc_subtensor(x[idxs], a)
func = aesara.function([idxs, x, a], out, mode="NUMBA")

func_inner = func.vm.jit_fn
_ = func_inner(idxs_vals, x_vals, a_vals)
print("time with non-const index:")
%timeit func_inner(idxs_vals, x_vals, a_vals)

x = at.dvector("x")
a = at.dvector("d")
out = at.inc_subtensor(x[idxs_vals], a)
func = aesara.function([x, a], out, mode="NUMBA")

func_inner = func.vm.jit_fn
func_inner(x_vals, a_vals);
print("time with const index:")
%timeit func_inner(x_vals, a_vals)
time with non-const index:
90.9 µs ± 3.61 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
time with const index:
62.7 µs ± 974 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

So the non-const index case is about 1.4x slower.
If we enable boundschecks, this increases to ~2x.

The asm of the non-const version without boundschecks looks ok, but it has to deal with the possibility of negative indices:

.LBB3_13:
	movq	(%rdx), %rax
	vmovsd	(%rsi), %xmm0
	movq	%rax, %rdi
	sarq	$63, %rdi
	andq	%r12, %rdi
	addq	%rax, %rdi
	vaddsd	(%rbx,%rdi,8), %xmm0, %xmm0
	vmovsd	%xmm0, (%rbx,%rdi,8)
	leaq	(%r9,%rdx), %rax
	movq	(%r9,%rdx), %rdx
	leaq	(%rcx,%rsi), %rdi
	vmovsd	(%rcx,%rsi), %xmm0
	movq	%rdx, %rsi
	sarq	$63, %rsi
	andq	%r12, %rsi
	addq	%rdx, %rsi
	vaddsd	(%rbx,%rsi,8), %xmm0, %xmm0
	vmovsd	%xmm0, (%rbx,%rsi,8)
	leaq	(%r9,%rax), %rsi
	movq	(%r9,%rax), %rax
	leaq	(%rcx,%rdi), %rbp
	vmovsd	(%rcx,%rdi), %xmm0
	movq	%rax, %rdx
	sarq	$63, %rdx
	andq	%r12, %rdx
	addq	%rax, %rdx
	vaddsd	(%rbx,%rdx,8), %xmm0, %xmm0
	vmovsd	%xmm0, (%rbx,%rdx,8)
	leaq	(%r9,%rsi), %rdx
	movq	(%r9,%rsi), %rax
	vmovsd	(%rcx,%rbp), %xmm0
	movq	%rax, %rdi
	sarq	$63, %rdi
	andq	%r12, %rdi
	addq	%rax, %rdi
	vaddsd	(%rbx,%rdi,8), %xmm0, %xmm0
	leaq	(%rcx,%rbp), %rsi
	vmovsd	%xmm0, (%rbx,%rdi,8)
	addq	%r9, %rdx
	addq	%rcx, %rsi
	addq	$-4, %r15
	jne	.LBB3_13

If we enable boundschecks, we get branching in the loop:

.LBB3_8:
	movq	(%rdx), %rbx
	movq	%rbx, %rax
	sarq	$63, %rax
	andq	%r12, %rax
	addq	%rbx, %rax
	cmpq	%r12, %rax
	jge	.LBB3_10
	testq	%rax, %rax
	js	.LBB3_10
	vmovsd	(%rdi), %xmm0
	vaddsd	(%rbp,%rax,8), %xmm0, %xmm0
	vmovsd	%xmm0, (%rbp,%rax,8)
	addq	%rcx, %rdx
	addq	%rsi, %rdi
	decq	%r14
	jne	.LBB3_8

In comparison the constant-index case looks pretty nice:

.LBB3_10:
	movq	(%rdi,%rbx), %rax
	vmovsd	(%rcx), %xmm0
	vaddsd	(%rbp,%rax,8), %xmm0, %xmm0
	vmovsd	%xmm0, (%rbp,%rax,8)
	movq	8(%rdi,%rbx), %rax
	vmovsd	(%r8,%rcx), %xmm0
	vaddsd	(%rbp,%rax,8), %xmm0, %xmm0
	vmovsd	%xmm0, (%rbp,%rax,8)
	movq	16(%rdi,%rbx), %rax
	vmovsd	(%rcx,%r8,2), %xmm0
	vaddsd	(%rbp,%rax,8), %xmm0, %xmm0
	vmovsd	%xmm0, (%rbp,%rax,8)
	movq	24(%rdi,%rbx), %rax
	vmovsd	(%rsi,%rcx), %xmm0
	vaddsd	(%rbp,%rax,8), %xmm0, %xmm0
	vmovsd	%xmm0, (%rbp,%rax,8)
	addq	%rdx, %rcx
	addq	$32, %rdi
	cmpq	$8192, %rdi
	jne	.LBB3_10

And we get safe indexing even without boundscheck=True.

@aseyboldt (Contributor Author)

This also raises the question of how we want to deal with out-of-bounds access by default.
I'm not really comfortable with a default implementation that doesn't check bounds for indexing. Should we enable bounds checking by default via a config option (right now we use the Numba default, boundscheck=False) and allow users to override that explicitly? Or should we write ops that index with user input to be safe even when boundscheck=False?

@brandonwillard (Member) commented Aug 6, 2022

You've demonstrated that there's possibly a clear difference between constant and non-constant inputs, but we really need to know whether or not all the extra code is providing value, and only a comparison with and without it would help determine that. Also, it's better if you provide the generated LLVM IR instead of the ASM generated for your machine.

@aseyboldt (Contributor Author)

I turned part of it into a rewrite, that makes it a bit cleaner. Apart from that I'm not really sure what extra code you are referring to.

@brandonwillard brandonwillard (Member) left a comment

The Numba transpilations look better, but we don't need the manual bounds check additions, especially not at the Op and/or rewrite levels.

Bounds-checking in Numba is entirely Numba's responsibility. If there's a bug, we can work around it with such additions, but I'm not aware of one; otherwise, we should support Numba options like bounds-checking in the ways we currently support other, similar options (e.g. aesara.config).

@aseyboldt (Contributor Author)

Numba's default for boundscheck is False, so unless we change that, at.as_tensor_variable(np.zeros(2))[3] has undefined behavior. I think this is a terrible API. We don't have this problem in most Numba ops, because most of the time user input can't make us access incorrect memory even if bounds checks are off, since we (hopefully) control the bounds of loops correctly.
So I think we really need bounds checks of some sort by default.
We can, however, sometimes tell that bounds checks are unnecessary and optimize them away (something that Numba can't do on its own), so why wouldn't we?

@brandonwillard (Member) commented Aug 15, 2022

Numba's default for boundscheck is False, so unless we change that, at.as_tensor_variable(np.zeros(2))[3] has undefined behavior. I think this is a terrible API.
We don't have this problem in most Numba ops, because most of the time user input can't make us access incorrect memory even if bounds checks are off, since we (hopefully) control the bounds of loops correctly.
So I think we really need bounds checks of some sort by default. We can, however, sometimes tell that bounds checks are unnecessary and optimize them away (something that Numba can't do on its own), so why wouldn't we?

Aesara's responsibility is to faithfully preserve the results of explicitly defined computations for valid inputs when transpiling to Numba and other targets. Since most errors aren't specified in an Aesara graph—aside from RaiseOps—they don't fall under that responsibility. In general, we currently aren't trying to preserve all of the behavior of a single target—Python included—across all other transpilation targets.

Regardless, the scope of this PR—and its related issue—does not cover manual bounds checking. We can discuss it in a new issue or Discussion, though.

@aseyboldt (Contributor Author)

I'm actually a bit shocked that you would accept something in Aesara where we access invalid memory for bad user input by default. I am not going to remove the bounds checks from this PR; I'd feel responsible for all the headaches that would lead to.

Aesara's responsibility is to faithfully preserve the results of explicitly defined computations for valid inputs when transpiling to Numba and other targets

Not sure where that is coming from. I'd say aesara's responsibility is to produce decent code, however that happens. And invalid memory access is certainly not that.

@brandonwillard (Member)

I'm actually a bit shocked you would accept something in aesara where we access invalid memory for wrong user input by default.

You seem to be aware that bounds checking already exists in Numba, but you're also assigning the same responsibility to Aesara. If you have this much disgust for a lack of bounds checking, then you need to take that up with Numba—and a few other programming languages, as well.

As I said above, if you found a bug in Numba's bounds checking that's solved by your implementation, please report it to them. We will always consider adding code that works around a current Numba bug or missing feature, but that doesn't seem to be the case here. If it is, tell us.

As you mentioned, we can override Numba's defaults and compile these graphs with bounds checking turned on by default. That's a viable approach. That's also a completely independent change; one that does not factor into the issue addressed by this PR.

I am not going to remove boundchecks from the PR, I'd feel responsible for all the headache that would lead to.

That's fine; we can take care of it.

Aesara's responsibility is to faithfully preserve the results of explicitly defined computations for valid inputs when transpiling to Numba and other targets

Not sure where that is coming from.

That's in reference to the kinds of computations that should be expressed in our Numba implementations of Aesara nodes. If you read the rest of what I wrote, you'll see how it relates to explicit error handling like the kind you've added.

I'd say aesara's responsibility is to produce decent code, however that happens. And invalid memory access is certainly not that.

Unless the code was explicitly designed/intended to prevent invalid memory accesses caused by bad user input, such an error says nothing about the quality of the code. It only says something about the purpose and/or expectations of the code. Sometimes the purpose/expectations for code involves performance, and bounds checking can hinder performance quite a bit. In that case, bounds checking would not make for decent code.

Regardless, unnecessary redundancy does not make code more decent, so, unless your additions are addressing something currently missing from Numba (as mentioned above), these changes are not more decent than the same code without the redundant bounds checking.

@aseyboldt (Contributor Author)

I kind of hope we are just talking past each other here, so I'll just summarize a bit, and hopefully that helps:

  • AdvancedIncSubtensor1 uses user-supplied indices, so I think there have to be bounds checks of some kind by default. This doesn't necessarily apply to most other ops, because most of the time we will only access invalid memory if there is a bug in Aesara, not if there is a bug in the user code.
  • Checked array access is the norm all over the Python ecosystem (Python itself, numpy, scipy, sklearn, pytorch, jax, tensorflow...). Numba is the only exception I can think of right now where it is not on by default. I don't agree with that default, but it is still much more reasonable in Numba itself than in Aesara, because if you use Numba directly it is much less hidden.

This is why I set boundscheck=True on the Numba implementation of that op, so that we use Numba's boundscheck support by default.

There is, however, a very common case where we can eliminate the bounds checks during graph execution and still have safe array access: if the array of indices is known at compile time, we can simply pre-compute the maximum and minimum entry and check at graph execution time whether those are valid for the other input arrays.

I proposed two different implementations of that, one where this happens entirely in the numba dispatch, and one where I moved it to the graph itself using a rewrite.
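Under the description above, the constant-index specialization might look roughly like the following plain-NumPy sketch (names are hypothetical; in the PR this lives in the numba dispatch/rewrite):

```python
import numpy as np

# Index array known at "compile time": precompute its bounds once.
const_idxs = np.array([0, 2, 2, -1])
max_idx = int(const_idxs.max())
min_idx = int(const_idxs.min())

def run_graph(x, vals):
    # One O(1) check per call replaces a bounds check per element
    # inside the increment loop.
    if max_idx >= len(x) or -min_idx > len(x):
        raise IndexError("constant indices out of bounds for input")
    np.add.at(x, const_idxs, vals)  # duplicate-safe accumulation
    return x
```

For example, `run_graph(np.zeros(3), np.ones(4))` returns `[1., 0., 3.]` (index 2 appears twice and -1 aliases it), while a length-2 input raises before any element is touched.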

Adrian Seyboldt added 3 commits August 22, 2022 19:25
We remove the default implementation for AdvancedIncSubtensor,
since it produces incorrect results for duplicate indexes.
This means we fall back to object mode for now until
we have a proper implementation of that as well.
@ricardoV94 (Contributor) commented Aug 23, 2022

Sounds like Numba gives us the option to enable bounds checks but doesn't do it by default. What's wrong with opting into that in Aesara?

We could have a Numba-specific Aesara flag for disabling bounds checks introduced in Numba Ops, if we don't want to add a Numba-specific variable at the Op level.

Also, what's with that check_input class variable in these Ops that doesn't seem to be used for anything?

@brandonwillard (Member)

Sounds like Numba gives us the option to enable bounds checks but doesn't do it by default. What's wrong with opting into that in Aesara?

We could have a Numba-specific Aesara flag for disabling bounds checks introduced in Numba Ops, if we don't want to add a Numba-specific variable at the Op level.

Yes, if we want bounds checking in Numba, we need to use Numba's bound checking. That's it.

Also what's with that check_input class variable in these Ops that doesn't seem to be used for anything?

I think that variable is used by COp to perform C-level checks/validation.

@aseyboldt (Contributor Author)

I give up.
Sorry @ricardoV94...

@aseyboldt aseyboldt closed this Aug 23, 2022
@ricardoV94 (Contributor)

Yes, if we want bounds checking in Numba, we need to use Numba's bound checking. That's it.

Isn't this what the current PR proposed? Or are you referring to the small check inner-function?

@brandonwillard (Member) commented Aug 24, 2022

Isn't this what the current PR proposed? Or are you referring to the small check inner-function?

This PR brings bounds checking to our graphs by adding a new property to AdvancedIncSubtensor1 Ops; however, we don't benefit from having bounds checking at the graph-level (i.e. via explicit CheckAndRaise nodes).

Our Python and C backends already perform bounds checking, and extra work would be needed in order to provide versions that don't (and use the newly introduced property). Likewise, if we're going to do anything with bounds checking at the graph-level, we would need to do it for all *Subtensor* Ops, and not just AdvancedIncSubtensor1.

Also, has anyone considered how adding a new property like that to AdvancedIncSubtensor1 affects other operations (e.g. node merging)?

Anyway, the approach in this PR consists of a much bigger set of changes than we need. We only need a simple implementation in the spirit of

@numba_njit(boundscheck=boundscheck)
def advancedincsubtensor1(x, vals, idxs):
    for idx, val in zip(idxs, vals):
        x[idx] += val
    return x

where boundscheck is potentially pulled from an aesara.config option or something similar. At the very least, we don't want to override a user's local/environment's Numba config options, so, if a user sets NUMBA_BOUNDSCHECK, Aesara should honor that value. That's a hard requirement.

Also, if we're adding bounds checking for AdvancedIncSubtensor, then we need to add it to all the other Numba *Subtensor* implementations. That's one of the reasons why this is a distinct issue that should be addressed in another PR.
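One hedged sketch of how that precedence could be resolved (resolve_boundscheck and the config fallback are hypothetical, not existing Aesara API): a user-set NUMBA_BOUNDSCHECK always wins, then a config value, then Numba's own default:

```python
import os

def resolve_boundscheck(config_value=None):
    # Hypothetical helper: honor the user's NUMBA_BOUNDSCHECK environment
    # variable first, as required above; otherwise fall back to an
    # (assumed) aesara.config option. Returning None would let Numba
    # apply its own default.
    env = os.environ.get("NUMBA_BOUNDSCHECK")
    if env is not None:
        return bool(int(env))
    return config_value
```

The resolved value would then be passed straight through as `numba_njit(boundscheck=...)`, so Aesara never silently overrides the user's environment.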

@brandonwillard (Member)

I've created #1143 to cover the issue associated with this PR. We can address the default bounds checking after that. In the meantime, if anyone wants (or ever wanted) bounds checking in Numba, they should be able to set NUMBA_BOUNDSCHECK, as Numba itself advises doing.


Successfully merging this pull request may close these issues.

Numba implementation of AdvancedIncSubtensor doesn't handle duplicate indices correctly