[PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear #1488
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch
ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad:
    ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
PyTorch recommends using ctx.save_for_backward instead of storing tensors directly in ctx. The docs warn about messing up the grad graph and memory leaks, although I'm not sure which cases they are specifically worried about.
I agree. We were saving main_grad tensors using ctx.save_for_backward in TE 1.x, but I see there is a comment here.
I'm wondering if we have seen issues with ctx.save_for_backward?
Interesting, we should follow the example of Linear then.
@ksivaman This change is from commit 7e58678 in the internal repo. Do you remember why we can't store main_grad in saved_tensors?
Oh I know: the previous prepare_for_saving saved tensor.data instead of the tensor itself. So for main_grad, which needs to be modified in place, this could be an issue.
Now that #1474 has changed prepare_for_saving to save the tensor itself, this is no longer a problem.
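For readers following this thread, here is a minimal sketch of the two points above. It is not TE's code: `ScaleByWeight` is a hypothetical toy function, and the example only assumes standard PyTorch behavior. It shows tensors being passed to backward through ctx.save_for_backward, and that tensor.data is a separate Python object that shares storage with the original tensor.

```python
import torch


class ScaleByWeight(torch.autograd.Function):
    """Hypothetical toy op: y = x * weight (elementwise)."""

    @staticmethod
    def forward(ctx, x, weight):
        # Tensors needed in backward go through save_for_backward
        # rather than being stored as plain attributes on ctx.
        ctx.save_for_backward(x, weight)
        return x * weight

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        # d(x * w)/dx = w and d(x * w)/dw = x
        return grad_out * weight, grad_out * x


x = torch.tensor([1.0, 2.0], requires_grad=True)
w = torch.tensor([3.0, 4.0], requires_grad=True)
ScaleByWeight.apply(x, w).sum().backward()

# tensor.data returns a different tensor object over the same storage.
alias = w.data
same_object = alias is w
same_storage = alias.data_ptr() == w.data_ptr()
```

Whether saving the `.data` alias matters in practice depends on what the backward pass expects from the saved object, which is the concern raised for main_grad above.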
outputs.append(p.grad)
if getattr(p, "main_grad", None) is not None:
    outputs.append(p.main_grad)
    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
It turns out Mcore expects p.grad to not be None: #1474 (comment)
#1474 sets grad to an uninitialized tensor and assumes Mcore will ignore it.
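One way to make the test helper robust to that behavior is sketched below. This is an illustration only: `collect_grads` is a hypothetical name, and the parameters are stand-in objects rather than real torch parameters. The idea is to prefer main_grad when it exists and not assert anything about p.grad, since it may hold an uninitialized placeholder.

```python
from types import SimpleNamespace


def collect_grads(params):
    """Return one gradient per parameter, preferring main_grad when present."""
    outputs = []
    for p in params:
        main_grad = getattr(p, "main_grad", None)
        if main_grad is not None:
            # Fused accumulation: the real gradient lives in main_grad;
            # p.grad may be a placeholder and is deliberately ignored.
            outputs.append(main_grad)
        else:
            outputs.append(p.grad)
    return outputs


# Stand-in parameters (hypothetical values, for illustration only).
fused = SimpleNamespace(main_grad=[1.0], grad="uninitialized placeholder")
plain = SimpleNamespace(grad=[2.0])
grads = collect_grads([fused, plain])
```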
/te-ci pytorch
* fix fuse_wgrad_accumulation for GroupedLinear
  Signed-off-by: Xin Yao <xiny@nvidia.com>
* fix fuse_wgrad_accumulation for GroupedLinear
  Signed-off-by: Xin Yao <xiny@nvidia.com>
* update tests
  Signed-off-by: Xin Yao <xiny@nvidia.com>
---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Description
Due to incorrect indentation, the wgrad computation is not called when ctx.fuse_wgrad_accumulation == True. This PR also updates the test to cover this case.
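The class of bug described here can be illustrated with a small pure-Python sketch. It does not reproduce the actual GroupedLinear backward code: `wgrad_buggy`, `wgrad_fixed`, and the list-based "gradients" are hypothetical stand-ins. It only shows how one extra level of indentation can make the accumulation step run solely in the non-fused branch.

```python
def wgrad_buggy(fuse_wgrad_accumulation, main_grad, delta):
    """Accumulation is wrongly nested under the else branch."""
    if fuse_wgrad_accumulation:
        out = main_grad  # should accumulate directly into main_grad
    else:
        out = [0.0] * len(delta)  # fresh output buffer
        for i, d in enumerate(delta):  # BUG: skipped when fusing
            out[i] += d
    return out


def wgrad_fixed(fuse_wgrad_accumulation, main_grad, delta):
    """Accumulation runs for both branches after the buffer is chosen."""
    if fuse_wgrad_accumulation:
        out = main_grad
    else:
        out = [0.0] * len(delta)
    for i, d in enumerate(delta):
        out[i] += d
    return out


buggy_fused = wgrad_buggy(True, [1.0, 2.0], [0.5, 0.5])
fixed_fused = wgrad_fixed(True, [1.0, 2.0], [0.5, 0.5])
fixed_plain = wgrad_fixed(False, [1.0, 2.0], [0.5, 0.5])
```

In the buggy variant the fused path returns main_grad unchanged, which is why a test that only exercises the non-fused path would not catch it.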