Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

转换规则 No. 302 #236

Merged
merged 2 commits into from
Aug 28, 2023
Merged

转换规则 No. 302 #236

merged 2 commits into from
Aug 28, 2023

Conversation

co63oc
Copy link
Collaborator

@co63oc co63oc commented Aug 22, 2023

PR Docs

#112

302 torch.sparse_csr_tensor 已有文档验证无误

PR APIs

@paddle-bot
Copy link

paddle-bot bot commented Aug 22, 2023

Thanks for your contribution!

@paddle-bot paddle-bot bot added contributor External developers status: proposed labels Aug 22, 2023
@@ -4057,6 +4057,58 @@ def get_paddle_class_nodes(self, func, args, kwargs):
).body


class SparseCsrMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个不可以直接复用GenericMatcher吗

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

复用 GenericMatcher有错误 ,dtype参数会替换为astype()格式,csr tensor不支持astype调用

图片

@luotao1 luotao1 added the HappyOpenSource 快乐开源活动issue与PR label Aug 22, 2023
Copy link
Collaborator

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

稀疏Tensor支持了cast API,cast 是通用的,你可以把GenericMatcher中的这段逻辑改成:

if dtype_v:
    res += ".astype({})".format(dtype_v)
if dtype_v:
    res += ".cast({})".format(dtype_v)

如果你想避免GenericMatcher处理某个参数,可以使用:

kwargs: ""

将其映射到空字符串

@co63oc
Copy link
Collaborator Author

co63oc commented Aug 23, 2023

sparse tensor 调用cast显示错误“Tensor holds no memory. Call Tensor::mutable_data firstly”
图片

代码

import paddle
x = paddle.rand((3, 4))
csr = x.to_sparse_csr()
csr = csr.cast(paddle.float32)
print(csr)

@zhwesky2010
Copy link
Collaborator

zhwesky2010 commented Aug 25, 2023

import paddle
x = paddle.rand((3, 4))
csr = x.to_sparse_csr()

可以使用:

kwargs: {
  "dtype": "dtype"
}

来配置,这样就不会使用默认的dtype处理方式

@co63oc
Copy link
Collaborator Author

co63oc commented Aug 25, 2023

来配置,这样就不会使用默认的dtype处理方式

已修改

Copy link
Collaborator

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhwesky2010 zhwesky2010 merged commit 632301d into PaddlePaddle:master Aug 28, 2023
8 checks passed
@co63oc co63oc deleted the api302 branch August 30, 2023 23:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers HappyOpenSource 快乐开源活动issue与PR status: proposed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants