[FSDP][optim_state_dict] Skip the parameter if the parameter does not belong to the current FSDP instance (pytorch#112804)

Skip the FSDP-managed parameter if the parameter is not managed by the current FSDP instance. This can happen when not all FSDP instances have all the parameters, for example with FSDP combined with some MPMD-style parallelism.
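
For illustration only, a minimal sketch of the lookup-and-skip pattern this change introduces; the function and data below (convert_states, the example FQNs and mapping) are hypothetical stand-ins, not the actual FSDP code path:

# Hypothetical sketch: replace a raising dict lookup with .get() and skip
# entries the current FSDP instance does not manage.
def convert_states(fqns, fqn_to_fsdp_param_info):
    converted = []
    for fqn in fqns:
        fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None)
        if fsdp_param_info is None:
            # Not managed by this FSDP instance (e.g. FSDP + MPMD-style
            # parallelism); skip instead of raising KeyError.
            continue
        converted.append((fqn, fsdp_param_info))
    return converted

# Usage: only "layer0.weight" is managed here, so "layer1.weight" is skipped.
print(convert_states(["layer0.weight", "layer1.weight"], {"layer0.weight": "info0"}))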

Differential Revision: [D50562170](https://our.internmc.facebook.com/intern/diff/D50562170/)
Pull Request resolved: pytorch#112804
Approved by: https://github.com/wz337
fegin authored and Skylion007 committed Nov 14, 2023
1 parent 77e3b5b commit 0b5c0a1
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion torch/distributed/fsdp/_optim_utils.py
@@ -1721,7 +1721,15 @@ def _convert_state_with_orig_params(

 if optim_state_key.is_fsdp_managed:
     fqn = optim_state_key.unflat_param_names[0]
-    fsdp_param_info = fqn_to_fsdp_param_info[fqn]
+    fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None)
+    if fsdp_param_info is None:
+        # This can happen if not all FSDP instances have all the
+        # parameters. This can happen with FSDP + some MPMD style
+        # parallelism.
+
+        # TODO: it is unclear if we need to do the same check with
+        # non-FSDP managed keys.
+        continue
     state = {} if param_key is None else optim_state_dict[param_key]
     if id(fsdp_param_info) not in all_states:
         all_states[id(fsdp_param_info)] = {}
