Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR】add sequence_mask in pir #59348

Merged
merged 18 commits into from
Dec 8, 2023

Conversation

yangguohao
Copy link
Contributor

PR types

Others

PR changes

Others

Description

add sequence_mask for pir

Copy link

paddle-bot bot commented Nov 24, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Nov 24, 2023
Copy link
Contributor

@kangguangli kangguangli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这样改试试,之所以提示API 变化可能是这个原因,虽然实际上对API没影响

paddle/fluid/pir/dialect/operator/ir/ops.yaml Outdated Show resolved Hide resolved
paddle/phi/api/yaml/op_compat.yaml Outdated Show resolved Hide resolved
kangguangli
kangguangli previously approved these changes Dec 4, 2023
args: (Tensor x, Scalar(int) max_len, int out_dtype)
output: Tensor(y)
infer_meta:
func: SequenceMaskPIRInferMeta
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

phi算子库下的算子和具体执行体系无关,后续也并非pir专用的,这里命名不太合理,建议改成SequenceMaskScalarInferMeta

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

infer_meta:
func: SequenceMaskPIRInferMeta
kernel:
func: sequence_mask_pir
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,建议改成sequence_mask_scalar

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 95 to 124
int maxlen = max_len.to<int>();
auto* x_data = x.data<T>();
auto x_numel = x.numel();

if (maxlen < 0) {
if (x_numel == 0) {
maxlen = 0;
} else {
#if defined(__NVCC__) || defined(__HIPCC__)
VLOG(10)
<< "SequenceMaskOp on GPU may be slow when maxlen is not provided.";
maxlen = static_cast<int>(
thrust::reduce(thrust::device_pointer_cast(x_data),
thrust::device_pointer_cast(x_data) + x_numel,
static_cast<T>(0),
thrust::maximum<T>()));
#else
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
#endif
}
}

auto y_dim = phi::vectorize<int>(x.dims());
y_dim.push_back(maxlen);
y->Resize(phi::make_ddim(y_dim));

phi::VisitDataType(phi::TransToPhiDataType(out_dtype),
phi::funcs::SequenceMaskFunctor<Context, T>(
ctx, x_data, y, x_numel * maxlen, maxlen));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

既然这里已经将代码拷贝过来了,可以将SequenceMaskKernel中的重复代码删掉,通过传入的max_len_tensor或者maxlen,构造一个Scalar类型的max_len,调用SequenceMaskScalarKernel即可

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

kangguangli
kangguangli previously approved these changes Dec 7, 2023
YuanRisheng
YuanRisheng previously approved these changes Dec 7, 2023
XieYunshen
XieYunshen previously approved these changes Dec 7, 2023
Copy link
Contributor

@XieYunshen XieYunshen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
单测超时时间设置

jeff41404
jeff41404 previously approved these changes Dec 7, 2023
jzhang533
jzhang533 previously approved these changes Dec 7, 2023
Copy link
Contributor

@jzhang533 jzhang533 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

zyfncg
zyfncg previously approved these changes Dec 7, 2023
paddle/phi/infermeta/binary.h Outdated Show resolved Hide resolved
heavyrain-lzy
heavyrain-lzy previously approved these changes Dec 7, 2023
Copy link
Contributor

@heavyrain-lzy heavyrain-lzy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for 'op_compat.yaml'

Copy link
Contributor

@jzhang533 jzhang533 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kangguangli kangguangli merged commit 669a300 into PaddlePaddle:develop Dec 8, 2023
28 of 29 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

8 participants