-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PIR】add sequence_mask in pir #59348
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这样改试试,之所以提示API 变化可能是这个原因,虽然实际上对API没影响
args: (Tensor x, Scalar(int) max_len, int out_dtype) | ||
output: Tensor(y) | ||
infer_meta: | ||
func: SequenceMaskPIRInferMeta |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
phi算子库下的算子和具体执行体系无关,后续也并非pir专用的,这里命名不太合理,建议改成SequenceMaskScalarInferMeta
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
infer_meta: | ||
func: SequenceMaskPIRInferMeta | ||
kernel: | ||
func: sequence_mask_pir |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,建议改成sequence_mask_scalar
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
int maxlen = max_len.to<int>(); | ||
auto* x_data = x.data<T>(); | ||
auto x_numel = x.numel(); | ||
|
||
if (maxlen < 0) { | ||
if (x_numel == 0) { | ||
maxlen = 0; | ||
} else { | ||
#if defined(__NVCC__) || defined(__HIPCC__) | ||
VLOG(10) | ||
<< "SequenceMaskOp on GPU may be slow when maxlen is not provided."; | ||
maxlen = static_cast<int>( | ||
thrust::reduce(thrust::device_pointer_cast(x_data), | ||
thrust::device_pointer_cast(x_data) + x_numel, | ||
static_cast<T>(0), | ||
thrust::maximum<T>())); | ||
#else | ||
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel)); | ||
#endif | ||
} | ||
} | ||
|
||
auto y_dim = phi::vectorize<int>(x.dims()); | ||
y_dim.push_back(maxlen); | ||
y->Resize(phi::make_ddim(y_dim)); | ||
|
||
phi::VisitDataType(phi::TransToPhiDataType(out_dtype), | ||
phi::funcs::SequenceMaskFunctor<Context, T>( | ||
ctx, x_data, y, x_numel * maxlen, maxlen)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
既然这里已经将代码拷贝过来了,可以将SequenceMaskKernel中的重复代码删掉,通过传入的max_len_tensor或者maxlen,构造一个Scalar类型的max_len,调用SequenceMaskScalarKernel即可
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
cd19ed2
to
c19fa9b
Compare
c19fa9b
to
a3e7e2c
Compare
f781324
to
0d378ab
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
单测超时时间设置
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for 'op_compat.yaml'
b6300af
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Description
add sequence_mask for pir