Skip to content

[Primitive][fork_rng] Do not replace module#76

Merged
comaniac merged 2 commits intoawslabs:mainfrom
comaniac:rng
Mar 8, 2023
Merged

[Primitive][fork_rng] Do not replace module#76
comaniac merged 2 commits intoawslabs:mainfrom
comaniac:rng

Conversation

@comaniac
Copy link
Contributor

@comaniac comaniac commented Mar 8, 2023

Description

This PR changes the way of implementing .fork_rng(). Now it replaces the forward function instead of an entire module. However, this approach results in incorrect trace results. Specifically, this module can still be traced, but the traced graph is based on the original forward function, which is incorrect. As a result, this PR also introduces an internal attribute .traceable to nn.Module, and we check this attribute to determine whether a Slapo customized module is traceable.

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

cc @chhzh123

Copy link
Contributor

@chhzh123 chhzh123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@comaniac comaniac merged commit 823796b into awslabs:main Mar 8, 2023
@comaniac
Copy link
Contributor Author

comaniac commented Mar 8, 2023

Thanks @chhzh123

@comaniac comaniac deleted the rng branch March 8, 2023 01:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants