[Hackathon 5th No.34] Add the bitwise_right_shift / bitwise_right_shift_ / bitwise_left_shift / bitwise_left_shift_ APIs to Paddle (update RFC) (#788)

* update bitwise shift rfc

* add details

* add notice for overflow

* update

* typo
cocoshe committed Jan 5, 2024
1 parent f867979 commit b109db8
Showing 1 changed file with 210 additions and 12 deletions.
222 changes: 210 additions & 12 deletions rfcs/APIs/20230927_api_design_for_bitwise_shift.md
100644 → 100755
@@ -1,4 +1,4 @@
# paddle.pdist design document
# paddle.bitwise_left_shift/paddle.bitwise_right_shift design document

| API name     | paddle.bitwise_right_shift<br />paddle.bitwise_left_shift |
| ------------ | --------------------------------------------------------- |
@@ -53,14 +53,104 @@ shift_right_arithmetic = _make_elementwise_binary_prim(
shift_right_logical = _not_impl # PyTorch evidently supports only arithmetic shifts
```

The element-level implementation: [code location](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/common.py#L401-L405)
The element-level implementation:

```python
@staticmethod
def bitwise_right_shift(a, b):
    return f"decltype({a})({a} >> {b})"
```
[Left-shift CPU kernel](https://github.com/pytorch/pytorch/blob/3747aca49a39479c2c5e223b91369db5bd339cdf/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L423-L437)

```cpp
void lshift_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [](scalar_t a, scalar_t b) -> scalar_t {
          constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
          if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
              (b >= max_shift)) {
            return 0;
          }
          return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
        },
        [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a << b; });
  });
}
```
[Left-shift CUDA kernel](https://github.com/pytorch/pytorch/blob/6e1ba79b7fdf3d66db8fb69462fb502e5006e5e7/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu#L14-L25)
```cpp
void lshift_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cuda", [&]() {
    gpu_kernel_with_scalars(iter,
      []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
        if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || (b >= max_shift)) {
          return 0;
        }
        return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
    });
  });
}
```

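+ As the kernels show, an arithmetic left shift has to [handle two special cases](https://wiki.sei.cmu.edu/confluence/display/c/INT34-C.+Do+not+shift+an+expression+by+a+negative+number+of+bits+or+by+greater+than+or+equal+to+the+number+of+bits+that+exist+in+the+operand):

  + When the shift distance `b` is greater than or equal to the bit width of the current type (e.g. shifting an int16 left by 16 bits), 0 is returned directly. (If the shift were actually executed, the shift count could effectively be reduced modulo the bit width, e.g. a left shift by 1000 bits would really shift by 1000%16=8 bits, whereas the required result is 0, indicating overflow.)
  + When `b` is negative, which is "undefined behavior" under the C standard, the shift is treated as equivalent to a left shift by infinitely many bits, so 0 is returned directly.

In addition, the kernel uses `static_cast<std::make_signed_t<scalar_t>>(b)` to cast `b` to a signed type. If `b` is already signed, the cast has no effect. If `b` is unsigned and its highest bit is 0, the cast has no effect either. If `b` is unsigned and large, with its highest bit set to `1`, the cast yields a negative number, i.e. a value less than 0. (That said, even without the cast, an unsigned value whose highest bit is 1 would presumably also make `(b >= max_shift)` true.)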
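The two guarded cases described above can be sketched in plain Python. This is only an illustrative model, not the PyTorch kernel itself: the helper names `wrap_to_width` and `guarded_left_shift` are hypothetical, and fixed-width integers are emulated by masking Python's unbounded `int`.

```python
def wrap_to_width(value: int, n_bits: int, signed: bool) -> int:
    """Reduce a Python int to an n_bits unsigned or two's-complement value."""
    value &= (1 << n_bits) - 1
    if signed and value >= 1 << (n_bits - 1):
        value -= 1 << n_bits
    return value


def guarded_left_shift(a: int, b: int, n_bits: int, signed: bool = True) -> int:
    """Left shift that returns 0 for a negative or too-large shift count."""
    # Both special cases share one branch, mirroring the single `if` in the kernel.
    if b < 0 or b >= n_bits:
        return 0
    return wrap_to_width(a << b, n_bits, signed)


print(guarded_left_shift(1, 3, 16))     # 8: an ordinary shift
print(guarded_left_shift(1, 16, 16))    # 0: shift count equals the bit width
print(guarded_left_shift(1, 1000, 16))  # 0: no modulo reduction to 8 bits
print(guarded_left_shift(1, -1, 16))    # 0: negative shift count
```

Keeping the range check outside the shift itself is what makes the result independent of how the underlying platform treats out-of-range shift counts.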



[Right-shift CPU kernel](https://github.com/pytorch/pytorch/blob/3747aca49a39479c2c5e223b91369db5bd339cdf/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L494-L511)

```cpp
void rshift_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [](scalar_t a, scalar_t b) -> scalar_t {
          // right shift value to retain sign bit for signed and no bits for
          // unsigned
          constexpr scalar_t max_shift =
              sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
          if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
              (b >= max_shift)) {
            return a >> max_shift;
          }
          return a >> b;
        },
        [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a >> b; });
  });
}
```
[Right-shift CUDA kernel](https://github.com/pytorch/pytorch/blob/6e1ba79b7fdf3d66db8fb69462fb502e5006e5e7/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu#L27-L39C2)
```cpp
void rshift_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cuda", [&]() {
    gpu_kernel_with_scalars(iter,
      []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        // right shift value to retain sign bit for signed and no bits for unsigned
        constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
        if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || (b >= max_shift)) {
          return a >> max_shift;
        }
        return a >> b;
    });
  });
}
```

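+ For an arithmetic right shift, `max_shift` has to reflect the largest meaningful shift distance. The highest bit of a signed number is the sign bit, so there is one fewer bit representing the value.
  + Signed: take `int8 x=-100`, whose two's complement is `1001,1100` with the highest bit as the sign bit. A right shift by only 7 bits turns every negative `int8` into `1111,1111`, i.e. `-1`, and every non-negative `int8` into `0000,0000`, i.e. `0`.
  + Unsigned: take `uint8 x=200`, stored as `1100,1000`, where all eight bits represent the value. A right shift by 8 bits is needed to turn every `uint8` into `0000,0000`, i.e. `0`.
+ A negative `b`, which is undefined behavior, is likewise treated as a right shift by infinitely many bits and is therefore equivalent to shifting by `max_shift` (in the code, `b<0` and `b>=max_shift` are tested in the same `if`). Whenever either condition holds, a negative signed number becomes `-1`, and a non-negative signed number or an unsigned number becomes `0`.

**The Paddle API follows the same approach: when `b` is negative or the shift exceeds the maximum, a negative signed number becomes `-1`, and a non-negative signed number or an unsigned number becomes `0`.**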
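The saturating behavior of the right shift can be modeled with a few lines of Python. This is a sketch under assumptions: `saturating_right_shift` is a hypothetical name, and it relies on Python's `>>` being an arithmetic (sign-preserving) shift on `int`.

```python
def saturating_right_shift(a: int, b: int, n_bits: int, signed: bool) -> int:
    """Arithmetic right shift that clamps out-of-range shift counts."""
    # A signed type spends one bit on the sign, so its threshold is n_bits - 1.
    max_shift = n_bits - 1 if signed else n_bits
    if b < 0 or b >= max_shift:
        b = max_shift
    return a >> b


print(saturating_right_shift(-100, 2, 8, signed=True))    # -25: ordinary shift
print(saturating_right_shift(-100, 7, 8, signed=True))    # -1: only sign bits remain
print(saturating_right_shift(-100, -3, 8, signed=True))   # -1: negative count saturates
print(saturating_right_shift(100, 50, 8, signed=True))    # 0: non-negative input
print(saturating_right_shift(200, 8, 8, signed=False))    # 0: unsigned input
```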





## Numpy
@@ -125,6 +215,27 @@ NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(@TYPE@_right_shift)
}
```
`npy_lshift` [related call](https://github.com/numpy/numpy/blob/0032ede015c9b06f88cc7f9b07138ce35f4357ae/numpy/_core/src/npymath/npy_math_internal.h.src#L653-L662):
```cpp
NPY_INPLACE npy_@u@@type@
npy_lshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
{
    if (NPY_LIKELY((size_t)b < sizeof(a) * CHAR_BIT)) {
        return a << b;
    }
    else {
        return 0;
    }
}
```

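+ For a left shift, the explicit range check guards against the shift count being reduced modulo the bit width (e.g. a left shift of an int16 by 100 bits would effectively become a shift by `100%16=4` bits), which would yield a nonzero result instead of 0 (overflow).

The code also converts `b` to `size_t`, which is an unsigned type. When `b` is a negative signed number, its two's-complement sign bit is 1, so it is converted to a very large positive number that necessarily exceeds `sizeof(a) * CHAR_BIT`. The `else` branch is taken and 0 is returned, which should have the same effect as an explicit `b < 0` check.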
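The effect of the `size_t` conversion can be reproduced in Python by masking to an unsigned width. The sketch below assumes a 64-bit `size_t`; `as_size_t` and `npy_style_left_shift` are hypothetical names used only for illustration, and the shifted value is modeled as an unsigned `n_bits` integer.

```python
SIZE_T_BITS = 64  # assumption: size_t is 64 bits wide on the target platform


def as_size_t(b: int) -> int:
    """Reinterpret a Python int as an unsigned SIZE_T_BITS-bit value."""
    return b & ((1 << SIZE_T_BITS) - 1)


def npy_style_left_shift(a: int, b: int, n_bits: int) -> int:
    """Unsigned n_bits left shift guarded by one unsigned comparison."""
    count = as_size_t(b)
    # A negative b becomes a huge unsigned count, so it fails this test
    # exactly like an oversized positive b would.
    if count < n_bits:
        return (a << count) & ((1 << n_bits) - 1)
    return 0


print(as_size_t(-1) == 2**64 - 1)         # True: -1 maps to the largest size_t
print(npy_style_left_shift(1, 4, 16))     # 16
print(npy_style_left_shift(1, 100, 16))   # 0
print(npy_style_left_shift(1, -1, 16))    # 0
```

A single unsigned comparison therefore covers both the negative and the oversized case.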



`npy_rshift` related call:

```cpp
@@ -145,6 +256,14 @@ npy_rshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
}
```

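+ For a right shift, the limit on the shift distance has to distinguish signed from unsigned numbers:

**The implementation here differs slightly from the one in PyTorch, but the results are equivalent. PyTorch uses a threshold of `n_bit-1` for signed numbers and `n_bit` for unsigned numbers (e.g. an int16 shifted right by 15 bits or more, or a uint16 shifted right by 16 bits or more, takes the overflow branch and every bit is set to the sign bit). Numpy does not deliberately distinguish the threshold for signed and unsigned numbers (e.g. int16 and uint16 both overflow only at 16 bits or more). For a signed type such as int16, "(PyTorch) a right shift by 15 bits triggers overflow and sets every bit to the sign bit" and "(Numpy) an actual right shift by 15 bits" give the same result; the former takes the overflow branch while the latter really performs the bit operation, so the two are equivalent.**



As with the left shift, the `(size_t)b` cast inside `NPY_LIKELY(...)` here implicitly requires `b` to be non-negative. If `b` is less than 0, then after conversion to unsigned it is necessarily larger than `sizeof(a) * CHAR_BIT`, i.e. it overflows, and the return value is then determined by the sign of `a` (a negative number overflows to the two's complement `1111,1111,...1111`, i.e. -1; positive and unsigned numbers overflow to 0).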
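The claimed equivalence can be checked exhaustively for a small width. The following sketch models both strategies in Python under assumptions: `torch_style_rshift` follows the `max_shift` logic of the kernels quoted earlier, `numpy_style_rshift` is written from the description above rather than copied from the NumPy source, and both names are hypothetical.

```python
def torch_style_rshift(a: int, b: int, n_bits: int, signed: bool) -> int:
    """Threshold is n_bits - 1 for signed types and n_bits for unsigned types."""
    max_shift = n_bits - 1 if signed else n_bits
    if b < 0 or b >= max_shift:
        return a >> max_shift
    return a >> b


def numpy_style_rshift(a: int, b: int, n_bits: int) -> int:
    """Threshold is n_bits for every type; overflow returns -1 or 0 by sign."""
    if 0 <= b < n_bits:
        return a >> b
    return -1 if a < 0 else 0


def styles_agree(n_bits: int, signed: bool) -> bool:
    """Compare both strategies over every value and a range of shift counts."""
    if signed:
        values = range(-(1 << (n_bits - 1)), 1 << (n_bits - 1))
    else:
        values = range(0, 1 << n_bits)
    return all(
        torch_style_rshift(a, b, n_bits, signed) == numpy_style_rshift(a, b, n_bits)
        for a in values
        for b in range(-2, 2 * n_bits)
    )


print(styles_agree(8, signed=True))   # True
print(styles_agree(8, signed=False))  # True
```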



## Jax
@@ -157,8 +276,8 @@ npy_rshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)

Parameters

- **x** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **y** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **x** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **y** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])

Return type[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array)

@@ -190,8 +309,8 @@ def _shift_right_arithmetic_raw(x, y):

Elementwise logical right shift: x ≫ y.Parameters

- **x** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **y** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **x** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])
- **y** ([`Union`](https://docs.python.org/3/library/typing.html#typing.Union)[[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array), [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [`bool_`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_), [`number`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number), [`bool`](https://docs.python.org/3/library/functions.html#bool), [`int`](https://docs.python.org/3/library/functions.html#int), [`float`](https://docs.python.org/3/library/functions.html#float), [`complex`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex)])

Return type[`Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array)

@@ -225,7 +344,9 @@ PyTorch registers the operator in its element-wise family, and Numpy similarly `BINARY_LOO

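In addition, PyTorch and Numpy support only arithmetic shifts, not logical shifts, whereas JAX implements both arithmetic and logical shifts.

PyTorch and Numpy handle shifts in essentially the same way. The Numpy survey above describes the slight difference between them in detail; that difference does not affect the final result, and the underlying approach is the same.

When the second argument `b` is negative, both treat it as a shift by an infinite distance (the two conditions sit in the same `if`, joined by a logical OR). For a right shift, a negative signed number then becomes `-1`, and a non-negative signed number or an unsigned number becomes `0`.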

# 5. Design Approach and Implementation Plan

@@ -235,7 +356,84 @@ The API is designed as `paddle.bitwise_right_shift(x, y, is_arithmetic=True)`, the rest

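## API Implementation Plan

Following the designs in `PyTorch`, `Numpy`, and `JAX`, implement the functionality by composing existing APIs

Because the dtype support of the relevant Python-level APIs does not meet the requirements (for example, with the jax design, converting unsigned to signed overflows), the implementation is moved down to the C++ layer. Taking the right shift as an example, the Python-level interface is `paddle.bitwise_right_shift`, and the `is_arithmetic` argument selects the arithmetic or the logical shift kernel: an arithmetic shift calls `_C_ops.bitwise_left_shift_arithmetic_(x, y)`, and a logical shift calls `_C_ops.bitwise_left_shift_logic_(x, y)`

The C++ kernels are implemented mainly in the elementwise style, similar to the design of bitwise ops such as `bitwise_and`, reusing the elementwise code to support broadcast, the invocation of the specific Functor, and so on.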



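Concrete behavior definition (`n_bits` is the storage width of the data type, e.g. `n_bits` is 8 for int8 and 16 for uint16; `y` less than 0 is "undefined behavior" and is treated the same as a shift beyond the maximum width, i.e. overflow):

+ Arithmetic shift

  + Arithmetic left shift: when `y` is less than 0 or greater than or equal to `n_bits`, the shift overflows and 0 is returned; otherwise a normal shift is performed;
  + Arithmetic right shift:
    + Signed: when `y` is less than 0 or greater than or equal to `n_bits`, the shift overflows and the result is filled with the sign bit (`a>>(n_bits-1)&1`), i.e. `-1` for a negative input and `0` otherwise; otherwise a normal shift is performed;
    + Unsigned: when `y` is less than 0 or greater than or equal to `n_bits`, the shift overflows and 0 is returned; otherwise a normal shift is performed;

+ Logical shift

  + Logical left shift: when `y` is less than 0 or greater than or equal to `n_bits`, the shift overflows and 0 is returned; otherwise a normal shift is performed;

  + Logical right shift:

    + Signed: when `y` is less than 0 or greater than or equal to `n_bits`, the shift overflows and 0 is returned; otherwise a special shift is performed: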

```cpp
template <typename T>
HOSTDEVICE T logic_shift_func(const T a, const T b) {
  if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
    return static_cast<T>(0);
  T t = static_cast<T>(sizeof(T) * 8 - 1);
  T mask = (((a >> t) << t) >> b) << 1;
  return (a >> b) ^ mask;
}
```

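In `T mask = (((a >> t) << t) >> b) << 1;`, `(a >> t)` first extracts the sign bit, `<< t` moves it back to its original position, the value is then shifted right by `b` and left by one bit, and finally it is XORed with the result of `a>>b`. Two examples: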

```
example1:
a = 1001,1010 b = 3, so t=7
((a>>t)<<t) = 1000,0000
mask=(((a>>t)<<t)>>b)<<1 = 1110,0000
a>>b = 1111,0011
so (a>>b) ^ mask = 0001,0011
example2:
a = 0001,1010 b = 3, so t=7
((a>>t)<<t) = 0000,0000
mask=(((a>>t)<<t)>>b)<<1 = 0000,0000
a>>b = 0000,0011
so (a>>b) ^ mask = 0000,0011
```

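+ Logical right shift, unsigned: when `y` is less than 0 or greater than or equal to `n_bits`, the shift overflows and 0 is returned; otherwise a normal shift is performed;

Of the behaviors above, the arithmetic shifts are aligned with the numpy and pytorch implementations. Since numpy and pytorch do not support logical shifts, the logical shifts follow the jax implementation approach and are implemented and verified indirectly with numpy.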
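The mask trick used for the signed logical right shift can be checked against a straightforward reference. The sketch below emulates an 8-bit signed type on top of Python's unbounded `int`; `wrap_signed`, `logical_right_shift_signed`, and `reference_logical_right_shift` are hypothetical helper names, and the reference simply reinterprets the bits as unsigned before shifting.

```python
def wrap_signed(value: int, n_bits: int) -> int:
    """Reduce a Python int to an n_bits two's-complement value."""
    value &= (1 << n_bits) - 1
    return value - (1 << n_bits) if value >= 1 << (n_bits - 1) else value


def logical_right_shift_signed(a: int, b: int, n_bits: int = 8) -> int:
    """Logical right shift built from arithmetic shifts and an XOR mask."""
    if b < 0 or b >= n_bits:
        return 0
    t = n_bits - 1
    # The mask has ones exactly where the arithmetic shift copied the sign bit.
    mask = (((a >> t) << t) >> b) << 1
    return wrap_signed((a >> b) ^ mask, n_bits)


def reference_logical_right_shift(a: int, b: int, n_bits: int = 8) -> int:
    """Reference: reinterpret the bits as unsigned, shift, reinterpret back."""
    if b < 0 or b >= n_bits:
        return 0
    return wrap_signed((a & ((1 << n_bits) - 1)) >> b, n_bits)


print(logical_right_shift_signed(-102, 3))  # 19, i.e. 1001,1010 -> 0001,0011
print(logical_right_shift_signed(26, 3))    # 3, i.e. 0001,1010 -> 0000,0011
print(all(
    logical_right_shift_signed(a, b) == reference_logical_right_shift(a, b)
    for a in range(-128, 128)
    for b in range(-2, 12)
))                                          # True
```

The two printed values correspond to example1 and example2 above.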


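+ Behavior of the **sign bit of a signed number** in each case:
  1. Arithmetic left shift: the sign bit shifts left together with the other bits, and zeros are filled in on the right;
  2. Logical left shift: the sign bit shifts left together with the other bits, and zeros are filled in on the right;
  3. Arithmetic right shift: the sign bit shifts right together with the other bits, and the sign bit is filled in on the left;
  4. Logical right shift: the sign bit shifts right together with the other bits, and zeros are filled in on the left;

Note: when a left shift of a signed number overflows, the value is unpredictable and the sign may flip abruptly. This is because the sign bit is shifted left as well, so the bits to the right of the sign bit successively become the new sign bit. For example: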
```
example1:
int8_t x = -45; // two's complement 1101,0011, i.e. -45
int8_t y = x << 2; // two's complement 0100,1100, i.e. 76
int8_t z = x << 3; // two's complement 1001,1000, i.e. -104
example2:
int8_t x = -86; // two's complement 1010,1010, i.e. -86
int8_t y = x << 1; // two's complement 0101,0100, i.e. 84
int8_t z = x << 2; // two's complement 1010,1000, i.e. -88
```
The above shows the abrupt sign changes caused by overflow.
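The sign flips in these two examples can be reproduced by truncating each shifted value to 8 bits. This is a minimal sketch that emulates `int8_t` wrap-around in Python; `shift_left_int8` is a hypothetical helper name.

```python
def shift_left_int8(x: int, n: int) -> int:
    """Shift left, keep the low 8 bits, and reinterpret them as a signed int8."""
    bits = (x << n) & 0xFF
    return bits - 256 if bits >= 128 else bits


print(shift_left_int8(-45, 2))  # 76: the sign bit was shifted out
print(shift_left_int8(-45, 3))  # -104: a 1 moved into the sign position
print(shift_left_int8(-86, 1))  # 84
print(shift_left_int8(-86, 2))  # -88
```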


# 6. Testing and Acceptance Considerations

@@ -261,4 +459,4 @@ The API is designed as `paddle.bitwise_right_shift(x, y, is_arithmetic=True)`, the rest

[PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.bitwise_right_shift.html?highlight=bitwise_right_shift#torch.bitwise_right_shift)

[Numpy documentation](https://numpy.org/doc/stable/reference/generated/numpy.right_shift.html#numpy.right_shift)
[Numpy documentation](https://numpy.org/doc/stable/reference/generated/numpy.right_shift.html#numpy.right_shift)
