[compiled autograd] Better cache miss logging #126602
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126602
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit f1bb00a with merge base 91bf952. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 9dc7b140c8f9c07f49d4632ebcfc4f69d1e6125e Pull Request resolved: #126602
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
ghstack-source-id: 6169b7700d9a4dc038272709da04221a79cb9379 Pull Request resolved: #126602
- log only first node key cache miss
- log existing node key sizes
- log which node's collected sizes became dynamic

e.g.
```
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
...
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::AccumulateGrad (NodeCall 5) with key size 32, previous key sizes=[21]
...
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 0 of torch::autograd::GraphRoot (NodeCall 0)
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 2 of SumBackward0 (NodeCall 1)
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 4 of SumBackward0 (NodeCall 1)
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 2 of ReluBackward0 (NodeCall 2)
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 9 of AddmmBackward0 (NodeCall 3)
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 2 of torch::autograd::AccumulateGrad (NodeCall 5)
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to dynamic shapes: collected size idx 2 of ReluBackward0 (NodeCall 6)
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
ghstack-source-id: c1d9ac86551adbae994990415adb92bf79c79798 Pull Request resolved: #126602
ghstack-source-id: b9a7699b8fee1018caf18ef7b17ed5e19181f304 Pull Request resolved: #126602
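The cache-miss messages in the commit description above are emitted on the `torch._dynamo.compiled_autograd.__compiled_autograd_verbose` logger, i.e. the artifact-logger form of a `compiled_autograd_verbose` logging artifact. A minimal sketch of turning it on, assuming the artifact is registered under that name in your PyTorch build (not something this PR itself documents):

```python
# Sketch: enable the compiled autograd verbose artifact logs.
# Assumes the artifact name "compiled_autograd_verbose"; if your build
# differs, check the registered artifacts in torch/_logging.
import torch._logging

torch._logging.set_logs(compiled_autograd_verbose=True)

# Roughly equivalent environment-variable form (set before the process starts):
#   TORCH_LOGS="compiled_autograd_verbose" python train.py
```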
@pytorchbot merge
Merge failed
Reason: This PR needs a `release notes:` label. If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`. If not, please add the `topic: not user facing` label.
To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`.
For more information, see the PyTorch wiki.
Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
Stack from ghstack (oldest at bottom):
- log only first node key cache miss
- log existing node key sizes
- log which node's collected sizes became dynamic

e.g. the verbose log output shown in the commit description above.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang
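For context, a toy script along these lines (a hedged sketch, not taken from this PR; it assumes `torch._dynamo.compiled_autograd.enable()` is the compiled autograd entry point and that the verbose artifact is enabled as above) should surface both kinds of cache miss: the first backward reports "new autograd node" misses while the cache is empty, and a backward over a different batch size can report "dynamic shapes" misses.

```python
import torch
import torch._dynamo.compiled_autograd
import torch._logging

# Enable the verbose cache-miss logs (assumed artifact name, see sketch above).
torch._logging.set_logs(compiled_autograd_verbose=True)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

def compiler_fn(gm):
    # Compile the captured backward graph; the backend choice is illustrative.
    return torch.compile(gm, backend="eager")

def train_step(x):
    model(x).sum().backward()
    model.zero_grad()

with torch._dynamo.compiled_autograd.enable(compiler_fn):
    train_step(torch.randn(4, 8))   # first call: "new autograd node" cache misses
    train_step(torch.randn(16, 8))  # new batch size: possible "dynamic shapes" misses
```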