
Add conversion rules #583

Merged
zhwesky2010 merged 10 commits into PaddlePaddle:master from Xuxuanang:Xuxuanang3
May 29, 2025

Conversation

@Xuxuanang
Contributor

@Xuxuanang Xuxuanang commented May 17, 2025

PR Docs

PaddlePaddle/docs#7306

PR APIs

torch.distributed.all_gather_object
torch.distributed.all_to_all_single
torch.distributed.reduce_scatter_tensor
torch._foreach_round
torch._foreach_round_
torch._foreach_sin
torch._foreach_sin_
torch._foreach_sinh
torch._foreach_sinh_
torch.utils.set_module
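Several of the APIs above belong to the `torch._foreach_*` family, which applies an op across a list of tensors (with the trailing-underscore variants working in place). As a rough illustration of that out-of-place vs. in-place contract, here is a minimal pure-Python sketch using nested lists of floats as stand-ins for tensors; these helper names are hypothetical and this is not PaConvert's generated code:

```python
import math

def foreach_sin(tensors):
    # Out-of-place: returns new lists, leaving the inputs untouched,
    # mirroring the contract of torch._foreach_sin.
    return [[math.sin(x) for x in t] for t in tensors]

def foreach_sin_(tensors):
    # In-place: mutates each "tensor" and returns the same objects,
    # mirroring the contract of torch._foreach_sin_.
    for t in tensors:
        for i, x in enumerate(t):
            t[i] = math.sin(x)
    return tensors
```

The same pattern applies to the round/sinh pairs in the list: one pure variant, one mutating variant with a trailing underscore.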

@paddle-bot

paddle-bot bot commented May 17, 2025

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label May 17, 2025
Collaborator

@zhwesky2010 zhwesky2010 left a comment


The docs need revision; make sure the implementation stays aligned with the approach described in the docs.

CODE_TEMPLATE = textwrap.dedent(
    """
    def reduce_scatter_tensor(output, input, op, group, async_op):
        input_list = [input[i] for i in range(input.shape[0])]

The split here needs to be based on world_size; see the comment in the docs.
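The reviewer's point: `reduce_scatter_tensor` scatters one chunk per rank, so the input must be split into `world_size` chunks along dim 0, not into one slice per row of `shape[0]`. A hedged sketch of just the splitting logic in plain Python (lists stand in for tensors; the function name is hypothetical, not the actual PaConvert template):

```python
def split_by_world_size(rows, world_size):
    # reduce_scatter_tensor hands each rank an equal contiguous chunk,
    # so split into world_size chunks, not one chunk per row.
    assert len(rows) % world_size == 0, "dim 0 must divide evenly by world_size"
    per_rank = len(rows) // world_size
    return [rows[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]
```

With 4 rows and `world_size=2` this yields two 2-row chunks, whereas splitting by `shape[0]` would produce 4 chunks and break the scatter.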

print(f"Rank {rank} output tensor: {output_tensor}")


if rank == 0:

This unit test is failing.

rank = dist.get_rank()
torch.cuda.set_device(rank)

input = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).cuda()

This actually splits by world_size, not by shape[0]; here shape[0] happens to equal world_size, which masked the problem.

So when setting up the input, make shape[0] != world_size so the test exercises this properly. Also, torch's shape can come in two forms here; test both.
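To see why the bug was masked: when shape[0] equals world_size, per-row splitting and per-rank splitting produce identical chunks, and they only diverge once shape[0] != world_size. A small pure-Python illustration (hypothetical helper names, lists standing in for tensors):

```python
def split_per_row(rows):
    # Buggy scheme: one chunk per row, i.e. shape[0] chunks.
    return [[r] for r in rows]

def split_per_rank(rows, world_size):
    # Correct scheme: world_size contiguous chunks.
    per = len(rows) // world_size
    return [rows[i * per:(i + 1) * per] for i in range(world_size)]

rows = ["a", "b", "c", "d"]
# shape[0] == world_size: the two schemes coincide, hiding the bug.
assert split_per_row(rows[:2]) == split_per_rank(rows[:2], 2)
# shape[0] != world_size: they diverge, exposing it.
assert split_per_row(rows) != split_per_rank(rows, 2)
```

This is why the reviewer asks for a test input with shape[0] != world_size.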

[screenshot: infoflow 2025-05-26 11-34-07]


print(f"Rank {rank} output tensor: {output_tensor}")

if rank == 0:

If this saves twice, can both results actually be tested? If not, you may need to write two unit tests.

torch.cuda.set_device(rank)

input = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).cuda()
output = torch.empty(1, 2, dtype=torch.float32).cuda()

Same as above: can both results be tested? If not, you may need to write two test files.

Collaborator

@zhwesky2010 zhwesky2010 left a comment


LGTM

@zhwesky2010 zhwesky2010 merged commit 272505c into PaddlePaddle:master May 29, 2025
8 checks passed