Weight tying + FSDP = out of bounds #257
Comments
I don't think this is an nvFuser issue. The nvFuser standalone repro does not fail. I wonder if nvFuser was just the place where the CUDA error was first caught. On an H100, I am seeing a different error with NCCL.
You are correct, Kevin. This is not an nvFuser issue. The code was also using some removed arguments, so I updated the description.
There are 2 problems at play here:

1. With weight tying, the shared parameter is visited once per module that references it, so `_shard_params` shards the same tensor more than once. The patch below keeps track of already sharded parameters so each underlying tensor is only sharded once:
diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py
index c9aa00a..5ae1554 100644
--- a/thunder/distributed/__init__.py
+++ b/thunder/distributed/__init__.py
@@ -13,6 +13,7 @@ from functools import partial
 import torch
 import torch.distributed as tdist
+from torch.utils.weak import WeakTensorKeyDictionary
 import thunder.core.utils as utils
 from thunder.core.proxies import DDPType
@@ -559,6 +560,9 @@ def _shard_params(
         local_rank = int(os.environ["LOCAL_RANK"])
         device = torch.device("cuda", local_rank)
+    # In case there is weight/param sharing, we don't want to shard the same param
+    # multiple times. We use `sharded_params` to keep track of already sharded param to avoid resharding it.
+    sharded_params = WeakTensorKeyDictionary()
     # We will definitely change the sharding logic in the future
     for module_name, submodule in module.named_modules():
         # Materialize meta-parameters on-device if necessary.
@@ -581,7 +585,10 @@
             # Note [FSDP Sharding]
             # All internal code will assume that the parameters are sharded on the first dimension
             for param_name, param in submodule.named_parameters(recurse=False, prefix=module_name):
+                if param in sharded_params:
+                    continue
                 _shard_param(param, global_rank, world_size, param_name, allow_padding_for_fsdp=allow_padding_for_fsdp)
+                sharded_params[param] = True
 def _shard_param(

NOTE: see lightning-thunder/thunder/distributed/__init__.py, lines 447 to 456 in 7d6e540.
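For context, here is a minimal, self-contained sketch of the dedup idea outside of Thunder. The `shard_first_dim` helper and the toy model are made up for illustration (they are not Thunder's actual `_shard_param`); the point is that tracking already-visited tensors with `WeakTensorKeyDictionary` shards a tied parameter only once:

```python
import torch
from torch.utils.weak import WeakTensorKeyDictionary


def shard_first_dim(param: torch.nn.Parameter, rank: int, world_size: int) -> None:
    # Hypothetical stand-in for thunder's _shard_param: keep only this rank's slice.
    param.data = param.data.chunk(world_size, dim=0)[rank].clone()


def shard_params(module: torch.nn.Module, rank: int, world_size: int) -> None:
    # Tied parameters (e.g. wte.weight and lm_head.weight) are the *same* tensor,
    # so they appear under several names; shard each underlying tensor only once.
    sharded = WeakTensorKeyDictionary()
    for submodule in module.modules():
        for param in submodule.parameters(recurse=False):
            if param in sharded:
                continue
            shard_first_dim(param, rank, world_size)
            sharded[param] = True


class TiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = torch.nn.Embedding(32000, 144)
        self.lm_head = torch.nn.Linear(144, 32000, bias=False)
        self.lm_head.weight = self.wte.weight  # weight tying


model = TiedModel()
shard_params(model, rank=0, world_size=2)
# Without the dedup, the shared weight would be chunked twice: [32000, 144] -> [16000, 144] -> [8000, 144].
print(model.wte.weight.shape)  # torch.Size([16000, 144])
assert model.lm_head.weight is model.wte.weight
```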
2. Even with the resharding fixed, the tied weights show up in the computation trace as two distinct tensors, so the shared (already sharded) weight is all-gathered twice:

# idx: "cuda:0 i64[128, 256]"
# tos1: "cuda:0 f32[256, 24]"
# t_lm_head_weight: "cuda:0 f32[16000, 144]"
p2 = torch_all_gather_prim_impl(t_lm_head_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p2: "FUTURE cuda:0 f32[32000, 144]"
p20 = torch_all_gather_prim_impl(t_transformer_wte_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p20: "FUTURE cuda:0 f32[32000, 144]"

where `torch_all_gather_prim_impl` comes from lightning-thunder/thunder/executors/torchex.py, lines 1753 to 1770 in 7d6e540.
To tackle 2, I think we need to add some notion of aliasing. This is related to in-place support (#145), which also has to consider aliasing.
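As an illustration of the aliasing in question (not Thunder's actual machinery), tied parameters can be detected by grouping a module's named parameters by the identity of the underlying tensor; the helper and toy model below are hypothetical:

```python
import torch
from collections import defaultdict


def find_tied_parameters(module: torch.nn.Module) -> list[list[str]]:
    # Group parameter names that refer to the same underlying tensor object.
    groups: dict[int, list[str]] = defaultdict(list)
    for name, param in module.named_parameters(remove_duplicate=False):
        groups[id(param)].append(name)
    return [names for names in groups.values() if len(names) > 1]


class TinyGPT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = torch.nn.ModuleDict({"wte": torch.nn.Embedding(32000, 144)})
        self.lm_head = torch.nn.Linear(144, 32000, bias=False)
        self.lm_head.weight = self.transformer["wte"].weight  # weight tying


print(find_tied_parameters(TinyGPT()))
# [['transformer.wte.weight', 'lm_head.weight']] -- one tensor, two names,
# so a trace that takes parameters by name will all-gather it twice.
```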
Could you please submit your fix for 1? It's a perfect solution to that problem. For 2, I think the Thunder JIT could recognize these situations and pass just one tensor to the computational trace.
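A rough sketch of what "pass just one tensor" could look like, outside of Thunder's real tracing machinery (the function and names here are assumptions): collapse aliased inputs by identity and remember which original argument maps to which unique tensor, so the trace can reuse one proxy for all aliases.

```python
import torch


def dedupe_flat_args(args: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[int]]:
    """Collapse aliased tensors so each distinct tensor is passed to the trace once.

    Returns the unique tensors plus, for every original argument, the index of the
    unique tensor it aliases.
    """
    unique: list[torch.Tensor] = []
    index_of: dict[int, int] = {}
    alias_map: list[int] = []
    for t in args:
        if id(t) not in index_of:
            index_of[id(t)] = len(unique)
            unique.append(t)
        alias_map.append(index_of[id(t)])
    return unique, alias_map


w = torch.randn(16000, 144)
idx = torch.randint(0, 32000, (128, 256))
unique, alias_map = dedupe_flat_args([idx, w, w])  # tied weight passed under two names
print(len(unique), alias_map)  # 2 [0, 1, 1] -> only one all-gather needed for the tied weight
```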
🐛 Bug
To Reproduce
Code:
Run with:
Error:
Removing one of:
makes the problem not appear
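A minimal sketch of the kind of script that combines weight tying with thunder's fsdp, assuming the fsdp-then-jit pattern from thunder's distributed examples (the model, sizes, and launch command are assumptions, not the original repro):

```python
# repro_sketch.py -- hypothetical stand-in for the original repro script
import os

import torch
import torch.distributed as dist

import thunder
from thunder.distributed import fsdp


class TiedModel(torch.nn.Module):
    def __init__(self, vocab_size: int = 32000, n_embd: int = 144):
        super().__init__()
        self.wte = torch.nn.Embedding(vocab_size, n_embd)
        self.lm_head = torch.nn.Linear(n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight  # weight tying

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.wte(idx))


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    model = TiedModel().to(device)
    model = fsdp(model)          # shard parameters on dim 0 across ranks
    jitted = thunder.jit(model)  # thunder-compiled module

    idx = torch.randint(0, 32000, (128, 256), device=device)
    out = jitted(idx)
    print(out.shape)
```

Launched with something like `torchrun --nproc_per_node=2 repro_sketch.py`.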
cc @carmocca @awaelchli @crcrpar