Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

转换规则 No.243-300 其他未领取项 #217

Merged
merged 4 commits into from
Aug 15, 2023

Conversation

co63oc
Copy link
Contributor

@co63oc co63oc commented Aug 2, 2023

PR Docs

#112

映射文档 PaddlePaddle/docs#6075

PR APIs

243 torch.linalg.inv_ex 功能缺失
246 torch.nn.LazyBatchNorm1d 功能缺失
247 torch.Tensor.sparse_resize_and_clear_ 功能缺失
248 torch.linalg.svdvals
251 torch.linalg.matrix_norm 功能缺失
254 torch.gradient 功能缺失
255 torch.special.gammaln
257 torch.nn.utils.parametrize.is_parametrized 功能缺失
259 torch.special.erfc 组合替代实现
260 torch.linalg.vector_norm 功能缺失
262 torch.nn.functional.group_norm 功能缺失
265 torch.special.expit 组合替代实现
266 torch.nn.Mish 已有文档验证无误
268 torch.linalg.cholesky_ex 功能缺失
269 torch.Tensor.sparse_resize_ 功能缺失
270 torch.positive 功能缺失
275 torch.special.exp2 组合替代实现
276 torch.nn.utils.parametrizations.spectral_norm
277 torch.cuda.StreamContext 功能缺失
278 torch.nn.LazyBatchNorm2d 功能缺失
281 torch.autograd.profiler.profile.self_cpu_time_total 功能缺失
284 torch.profiler.ProfilerActivity 功能缺失
286 torch.special.entr 功能缺失
287 torch.nn.LazyBatchNorm3d 功能缺失
288 torch.autograd.function.FunctionCtx.mark_dirty 功能缺失
289 torch.autograd.function.FunctionCtx.mark_non_differentiable
290 torch.autograd.Function.forward
291 torch.autograd.function.Function
292 torch.autograd.function.FunctionCtx.set_materialize_grads
293 torch.autograd.function.FunctionCtx.save_for_backward
294 torch.autograd.Function.backward
295 torch.profiler.ProfilerAction 功能缺失
296 torch.cuda.set_stream
300 torch.autograd.graph.saved_tensors_hooks

@paddle-bot
Copy link

paddle-bot bot commented Aug 2, 2023

Thanks for your contribution!

@co63oc
Copy link
Contributor Author

co63oc commented Aug 2, 2023

FunctionCtx对象方法不能识别,torch.autograd.function.FunctionCtx.mark_dirty 增加一行代码用来识别
图片

@luotao1 luotao1 added the HappyOpenSource 快乐开源活动issue与PR label Aug 3, 2023
tests/test_autograd_Function_backward.py Outdated Show resolved Hide resolved
tests/test_autograd_Function_backward.py Outdated Show resolved Hide resolved
tests/test_autograd_function_FunctionCtx_mark_dirty.py Outdated Show resolved Hide resolved
return x+x+x, x+x

@staticmethod
def backward(ctx, grad, grad2):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

图片

saved_tensors 没有合适转换方式,ctx不能识别

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不能识别应该也会原封不动,这里如果函数名一致是不是也不会出错

Copy link
Contributor Author

@co63oc co63oc Aug 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddle里是saved_tensor(),不是属性,不转换调用会有错误
图片

tests/test_cuda_StreamContext.py Show resolved Hide resolved
tests/test_gradient.py Show resolved Hide resolved
@zhwesky2010
Copy link
Collaborator

@co63oc 另外还需要处理下冲突

@co63oc
Copy link
Contributor Author

co63oc commented Aug 3, 2023

@co63oc 另外还需要处理下冲突

已修改

Copy link
Collaborator

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后面写新增Matcher时可以按这个标准来哈:

Uploading infoflow 2023-08-03 12-31-08.png…

else:
API_TEMPLATE = textwrap.dedent(
"""
_, {}, _ = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接写{}[1] 可以吧,省一行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果转换后代码末尾为[]符号,测试框架会不转换api ,所以这里增加一行取数组元素,然后赋值

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在 paconvert/transformer/basic_transformer.py:337 这里加一下 ast.Subscript 就可以

return x+x+x, x+x

@staticmethod
def backward(ctx, grad, grad2):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不能识别应该也会原封不动,这里如果函数名一致是不是也不会出错

"""
import torch
x = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]])
result = torch.linalg.svdvals(x, driver=None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

driver是属于 暂无转写方式 还是 可直接删除

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

driver是可直接删除,已修改文档

check_dtype=True,
check_stop_gradient=True,
):
assert isinstance(paddle_result, paddle.nn.Linear)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个映射是计算结果对不上吗?那就可能映射的不对了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

图片

返回类型不是Tensor类型,所以用isinstance判断

@co63oc
Copy link
Contributor Author

co63oc commented Aug 4, 2023

后面写新增Matcher时可以按这个标准来哈:

好的

else:
API_TEMPLATE = textwrap.dedent(
"""
_, {}, _ = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方改成 {}[1],在 paconvert/transformer/basic_transformer.py:337 这里加一下 ast.Subscript 就可以通过了,可以省去一行

Copy link
Collaborator

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhwesky2010 zhwesky2010 merged commit 6c56024 into PaddlePaddle:master Aug 15, 2023
@co63oc co63oc deleted the api255 branch August 19, 2023 00:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers HappyOpenSource 快乐开源活动issue与PR status: proposed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants