Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.34】 为 Paddle 新增 bitwise_right_shift / bitwise_right_shift_ / bitwise_left_shift / bitwise_left_shift_ API (update RFC) #788

Merged
merged 5 commits into from
Jan 5, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
195 changes: 184 additions & 11 deletions rfcs/APIs/20230927_api_design_for_bitwise_shift.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,102 @@ shift_right_arithmetic = _make_elementwise_binary_prim(
shift_right_logical = _not_impl # 可见pytorch中仅支持算数位移
```

具体元素尺度的实现,[代码位置](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/common.py#L401-L405):
具体元素尺度的实现,

```python
@staticmethod
def bitwise_right_shift(a, b):
return f"decltype({a})({a} >> {b})"
[左移 cpu kernel](https://github.com/pytorch/pytorch/blob/3747aca49a39479c2c5e223b91369db5bd339cdf/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L423-L437):

```cpp
void lshift_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
(b >= max_shift)) {
return 0;
}
return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
},
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a << b; });
});
}
```

[左移 cuda kernel](https://github.com/pytorch/pytorch/blob/6e1ba79b7fdf3d66db8fb69462fb502e5006e5e7/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu#L14-L25)

```cpp
void lshift_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cuda", [&]() {
gpu_kernel_with_scalars(iter,
[]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || (b >= max_shift)) {
return 0;
}
return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
});
});
}
```

+ 可以发现,在算术左移时,kernel需要针对[两类情况进行处理](https://wiki.sei.cmu.edu/confluence/display/c/INT34-C.+Do+not+shift+an+expression+by+a+negative+number+of+bits+or+by+greater+than+or+equal+to+the+number+of+bits+that+exist+in+the+operand):

+ `b`移动的距离大于等于当前类型的位数时(例如对int16左移16位),则直接返回0(若进行移动,编译器会在此时发生取模优化,例如左移1000位时,实际上会移动1000%16=8位,但实际上需要返回0,表示溢出)
+ `b`为负数时,在C语言标准中为"未定义行为",认为等效于左移了无穷位,直接返回0;

另外,kernel中用`std::make_signed_t<scalar_t>>(b)`把`b`强转为有符号数,若`b`原本就是有符号数,无影响;若`b`原本是无符号数,且最高位为0,无影响;若`b`原本是无符号数,而且较大,最高位为`1`,强转后为负数,小于0。(不过感觉即使不强转,最高位为1的无符号数应该也会令`(b >= max_shift)`为true)



[右移 cpu kernel](https://github.com/pytorch/pytorch/blob/3747aca49a39479c2c5e223b91369db5bd339cdf/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L494-L511)

```cpp
void rshift_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
// right shift value to retain sign bit for signed and no bits for
// unsigned
constexpr scalar_t max_shift =
sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
(b >= max_shift)) {
return a >> max_shift;
}
return a >> b;
},
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a >> b; });
});
}
```

[右移 cuda kernel](https://github.com/pytorch/pytorch/blob/6e1ba79b7fdf3d66db8fb69462fb502e5006e5e7/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu#L27-L39C2)

```cpp
void rshift_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cuda", [&]() {
gpu_kernel_with_scalars(iter,
[]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
// right shift value to retain sign bit for signed and no bits for unsigned
constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || (b >= max_shift)) {
return a >> max_shift;
}
return a >> b;
});
});
}
```

+ 算术右移时,`max_shift`需要考虑最大的移动距离,有符号数最高位为符号位,故表示数值的位数实际上会少一位。
+ 有符号数时,例如`int8 x=-100`,补码为`1001,1100`,最高位为符号位,仅需要右移7位,所有的`int8`就都会变成`1111,1111`,即`-1`;
+ 无符号数时候,例如`uint8 x=200`,存储为`1100,1000`,八位均表示数值大小,需要右移8位才可以将所有的`uint8`变为`0000,0000`,即`0`;
+ 当`b`位负数这一未定义行为时,同样等效于右移无穷位,与移动`max_shift`等效,有符号数变为`-1`,无符号数变为`0`

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

错别字,“位“应该为”为“。对于未定义的行为怎么处理,应该列出竞品实际是如何处理的,处理方式是否一致,这些都是要说清楚的。





## Numpy
Expand Down Expand Up @@ -125,6 +213,27 @@ NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(@TYPE@_right_shift)
}
```

`npy_lshift`[相关调用](https://github.com/numpy/numpy/blob/0032ede015c9b06f88cc7f9b07138ce35f4357ae/numpy/_core/src/npymath/npy_math_internal.h.src#L653-L662):

```cpp
NPY_INPLACE npy_@u@@type@
npy_lshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
{
if (NPY_LIKELY((size_t)b < sizeof(a) * CHAR_BIT)) {
return a << b;
}
else {
return 0;
}
}
```

+ 在左移时,为了防止编译器对位移的自动取模优化(例如int16类型左移100位,实际上被自动优化成左移`100%16=4`位),导致结果不为0(溢出);

而且这里将`b`转为`size_t`,而`size_t`是unsigned类型,所以当`b`为有符号负数时,由于补码最高位的符号位为1,所以会被转换成一个很大的正数,必然超过`sizeof(a) * CHAR_BIT`的大小,所以直接走else返回0,这里应该与`b < 0`实现了同样的效果。



`npy_rshift`相关调用

```cpp
Expand All @@ -145,6 +254,14 @@ npy_rshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
}
```

+ 在右移时,右移的最大位数限制需要区分有符号数和无符号数:

**此处实现与pytorch中的实现略有不同,不过结果还是等效的:pytorch中认为,有符号数最大右移位数为`n_bit-1`,而无符号数最大右移位数为`n_bit`,例如(int16最多右移15位,uint16最多右移16位,否则触发溢出,全置为符号位);numpy中没有刻意限定符号数和无符号数的最大位移位数(例如int16和uint16的最大位移位数都是16位,都是16位才出发溢出),由于对于有符号数例如int16来说,“(pytorch)右移15位触发溢出,全部置为符号位”与“(numpy)右移15位”,两者结果是一样的,只是前者直接走溢出的else,后者真正去做了位运算而已,所以还是等效**



这里的`NPY_LIKELY((size_t)b`与左移一样,隐含了`b`需要大于0。若`b`小于0,则转unsigned之后大小必然大于`sizeof(a) * CHAR_BIT`溢出,而后又根据`a`的符号位作为返回(负数溢出补码为`1111,1111,...1111`,也就是-1,正数和无符号数溢出为0)。



## Jax
Expand All @@ -157,8 +274,8 @@ npy_rshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)

Parameters

- **x** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **y** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **x** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **y** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])

Return type[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array)

Expand Down Expand Up @@ -190,8 +307,8 @@ def _shift_right_arithmetic_raw(x, y):

Elementwise logical right shift: x ≫ y.Parameters

- **x** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **y** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **x** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **y** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])

Return type[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array)

Expand Down Expand Up @@ -235,7 +352,63 @@ API的设计为`paddle.bitwise_right_shift(x, y, is_arithmetic=True)`,其余

## API实现方案

参考`PyTorch`、`Numpy`、`JAX`中的设计,组合已有API实现功能

由于python层相关API的类型支持需求不合理(例如jax中的设计,unsigned转signed会溢出),考虑下沉到cpp层实现。以右移为例,python层接口为`paddle.bitwise_right_shift`,通过参数`is_arithmetic`的设置来调用算术位移或逻辑位移的kernel,若为算术位移,则调用`_C_ops.bitwise_left_shift_arithmetic_(x, y)`,若为逻辑位移,则调用`_C_ops.bitwise_left_shift_logic_(x, y)`

cpp的kernel实现主要通过elementwise的方法,与`bitwise_and`等bitwise op设计类似,复用elementwise相关代码以支持broadcast、具体Functor的调用等。



具体行为定义:(`n_bits`表示数据类型存储位数,例如int8的`n_bits`为8,uint16的`n_bits`为16;当`y`小于0时为“未定义行为”,等效于位移超过最大位数溢出)

+ 算术位移

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注意到一种情形,比如x为有符号数,比如x=-45(int8),无论是算术移位还是逻辑移位,其左移2位为76,左移3位为-104。这种有符号数左移出现时正时负的情形,可能会让人迷惑,RFC中还可以详细描述下不同情形下,对于有符号数x的符号位的处理逻辑,相应的例子和描述也可以更新到API文档中。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注意到一种情形,比如x为有符号数,比如x=-45(int8),无论是算术移位还是逻辑移位,其左移2位为76,左移3位为-104。这种有符号数左移出现时正时负的情形,可能会让人迷惑,RFC中还可以详细描述下不同情形下,对于有符号数x的符号位的处理逻辑,相应的例子和描述也可以更新到API文档中。

嗯嗯好,这种属于溢出,位移溢出的结果都是不可控无意义的,我说明一下~

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补充下竞品对于未定义操作的的实际处理方式,看是否和当前现有设计方案是一致的


+ 算术左移:当`y`小于0,或者`y`大于等于`n_bits`时候溢出,返回0;否则正常位移;
+ 算术右移:
+ 有符号数时:当`y`小于0,或者`y`大于等于`n_bits`时候溢出,返回符号位(`a>>(n_bits-1)&1`);否则正常位移;
+ 无符号数时:当`y`小于0,或者`y`大于等于`n_bits`时候溢出,返回0;否则正常位移;

+ 逻辑位移

+ 逻辑左移:当`y`小于0,或者`y`大于等于`n_bits`时候溢出,返回0;否则正常位移;

+ 逻辑右移:

+ 有符号数时:当`y`小于0,或者`y`大于等于`n_bits`时候溢出,返回0;否则特殊位移:

```cpp
template <typename T>
HOSTDEVICE T logic_shift_func(const T a, const T b) {
if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(0);
T t = static_cast<T>(sizeof(T) * 8 - 1);
T mask = (((a >> t) << t) >> b) << 1;
return (a >> b) ^ mask;
}
```

在`T mask = (((a >> t) << t) >> b) << 1;`中,先`(a >> t)`取符号位,然后`<< t`回到原位,再右移`b`后左移一位,最后与`a>>b`的结果做亦或,下面举两个例子:

```
example1:
a = 1001,1010 b = 3, 有t=7
((a>>t)<<t) = 1000,0000
mask=(((a>>t)<<t)>>b)<<1 = 1110,0000
a>>b = 1111,0011
所以 (a>>b) ^ mask = 0001,0011

example2:
a = 0001,1010 b = 3, 有t=7
((a>>t)<<t) = 0000,0000
mask=(((a>>t)<<t)>>b)<<1 = 0000,0000
a>>b = 0000,0011
所以 (a>>b) ^ mask = 0000,0011
```

+ 无符号数时:当`y`小于0,或者`y`大于等于`n_bits`时候溢出,返回0;否则正常位移;

以上行为中,算术位移与numpy、pytorch的实现对齐;由于numpy和pytorch不支持逻辑位移,所以逻辑位移参考jax的实现思路,用numpy来进行间接实现和验证。


# 六、测试和验收的考量

Expand All @@ -261,4 +434,4 @@ API的设计为`paddle.bitwise_right_shift(x, y, is_arithmetic=True)`,其余

[PyTorch文档](https://pytorch.org/docs/stable/generated/torch.bitwise_right_shift.html?highlight=bitwise_right_shift#torch.bitwise_right_shift)

[Numpy文档](https://numpy.org/doc/stable/reference/generated/numpy.right_shift.html#numpy.right_shift)
[Numpy文档](https://numpy.org/doc/stable/reference/generated/numpy.right_shift.html#numpy.right_shift)