Skip to content

Commit

Permalink
【Hackathon 5th No.28】为 Paddle 新增 slice_scatter API v2 (#790)
Browse files Browse the repository at this point in the history
* [Add] hack5 28 rfc

* [Change] use set_value op

* [Fix] layer name

* [Change] values to value

* [Change] slice_scatter with list of int params

* [Change] test
  • Loading branch information
megemini committed Dec 26, 2023
1 parent 4a042c8 commit 3b0188b
Showing 1 changed file with 19 additions and 61 deletions.
80 changes: 19 additions & 61 deletions rfcs/APIs/20231206_api_design_for_slice_scatter.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
| API名称 | paddle.slice_scatter |
| ------------------------------------------------------------ | ----------------------------------------- |
| 提交作者<input type="checkbox" class="rowselector hidden"> | megemini (柳顺) |
| 提交时间<input type="checkbox" class="rowselector hidden"> | 2023-12-13 |
| 版本号 | V1.0 |
| 提交时间<input type="checkbox" class="rowselector hidden"> | 2023-12-22 |
| 版本号 | V2.0 |
| 依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | develop |
| 文件名 | 20231213_api_design_for_slice_scatter.md<br> |

**修订记录**

v2.0 修改函数签名,支持 `list of int` 的参数


# 一、概述
Expand Down Expand Up @@ -266,17 +269,17 @@ paddle 目前的 `set_value` 算子已经支持 `axes`, `starts`, `ends`, `steps

添加 Python API:
```python
paddle.slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None)
paddle.slice_scatter(x, value, axes, starts, ends, strides, name=None)
```

参数表:

- x: (Tensor) 输入的 tensor。数据类型支持 `float32``float64`
- value: (Tensor) 用于填充的 tensor。数据类型与input一致,形状与`x[*x.shape[:axis], start:end:step, *x.shape[axis+1:]]`取出的slice一致
- axis: (int) y的数据将被填充至x的axis维度。
- start: (Optional[int]) 待插入slice位置的起始index。
- stop: (Optional[int]) 待插入slice位置的结束index。
- step: (int) 待插入slice的步长。
- x: (Tensor) 输入的 tensor。
- value: (Tensor) 用于填充的 tensor。数据类型与input一致。
- axes: (list|tuple) y的数据将被填充至x的axis维度。
- starts: (list|tuple) 待插入slice位置的起始index。
- ends: (list|tuple) 待插入slice位置的结束index。
- strides: (list|tuple) 待插入slice的步长。
- name: (Optional[str]) op 名称

## 底层OP设计
Expand All @@ -288,53 +291,16 @@ paddle.slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None)
此次使用 `set_value` 算子实现接口:

``` python
def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None):

if x.ndim != value.ndim:
raise ValueError(
f"The input x and value should have save dimension, but got input of {x.ndim} and value of {value.ndim}."
)

x_shape = x.shape
value_shape = value.shape

index = list(range(start or 0, stop or x_shape[axis], step))
exp_shape = [*x_shape[:axis], len(index), *x_shape[axis+1:]]
if exp_shape != value_shape:
raise ValueError(
"The value.shape should be same of [*x_shape[:axis], len(index), *x_shape[axis+1:]],"
f"but got value.shape of {value.shape} and slice shape {exp_shape}."
)

starts = [start]
ends = [stop]
steps = [step]
axes = [axis]
none_axes = []
decrease_axes = []
inputs = {'Input': x}
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'steps': steps,
'decrease_axes': decrease_axes,
'none_axes': none_axes,
}

dtype = x.dtype
attrs['dtype'] = dtype

value = value.astype(dtype)
inputs["ValueTensor"] = value
def slice_scatter(x, value, axes, starts, ends, strides, name=None):
... check params

if in_dynamic_or_pir_mode():
return _C_ops.set_value_with_tensor(
x,
value,
starts,
ends,
steps,
strides,
axes,
decrease_axes,
none_axes,
Expand All @@ -354,22 +320,14 @@ def slice_scatter(x, value, axis=0, start=None, stop=None, step=1, name=None):
return output
```

有几点说明:

- x 与 src 需要有相同的 ndim
- values_shape 需要与 slice 的 exp_shape 一致
- 参数 axis/start/stop/step 不支持 list。因为,多个 axis 的话可能导致 slice 的 shape 错误。
比如,x 为 [8, 8], src 为 [8, 2],则 axis 只能为 1。


# 六、测试和验收的考量

- 覆盖动态图和静态图的测试场景
- 覆盖 CPU、GPU 两种测试场景
- 支持各种Tensor精度,FP32、FP64(带验证
- 需要检查前向和反向计算的精度正确性
- 处理0维输入数据
- 处理可选参数不存在或不一致的情况
- 支持各种Tensor精度,FP32、FP64 等(待验证
- 需要检查计算正确性
- 需要检查多维的情况
- 需要检查 broadcast 情况

# 七、可行性分析和排期规划

Expand Down

0 comments on commit 3b0188b

Please sign in to comment.