From 0b5c0a10b935a22afc0d8ee0f0e7385f38213e5a Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Fri, 3 Nov 2023 13:08:51 -0700
Subject: [PATCH] [FSDP][optim_state_dict] Skip the parameter if the parameter
 does not belong to the current FSDP instance (#112804)

Skip an FSDP-managed parameter if the parameter is not managed by the
current FSDP instance. This can happen when not all FSDP instances have
all the parameters, e.g., with FSDP combined with some MPMD-style
parallelism.

Differential Revision: [D50562170](https://our.internmc.facebook.com/intern/diff/D50562170/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112804
Approved by: https://github.com/wz337
---
 torch/distributed/fsdp/_optim_utils.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py
index 0117aa73a08cf..60eb666b0a1c8 100644
--- a/torch/distributed/fsdp/_optim_utils.py
+++ b/torch/distributed/fsdp/_optim_utils.py
@@ -1721,7 +1721,15 @@ def _convert_state_with_orig_params(
         if optim_state_key.is_fsdp_managed:
             fqn = optim_state_key.unflat_param_names[0]
-            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
+            fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None)
+            if fsdp_param_info is None:
+                # This can happen when not all FSDP instances have all
+                # the parameters, e.g., with FSDP combined with some
+                # MPMD-style parallelism.
+
+                # TODO: it is unclear if we need to do the same check
+                # with non-FSDP-managed keys.
+                continue
             state = {} if param_key is None else optim_state_dict[param_key]
             if id(fsdp_param_info) not in all_states:
                 all_states[id(fsdp_param_info)] = {}
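
Note: the sketch below is not part of the patch; it is a minimal standalone illustration of the lookup-and-skip pattern the hunk introduces. `FSDPParamInfo`, `gather_states`, and the FQN strings are hypothetical stand-ins for the real structures in `torch/distributed/fsdp/_optim_utils.py`.

```python
from typing import Dict, List, Optional


class FSDPParamInfo:
    """Hypothetical stand-in for the per-parameter info kept by an FSDP instance."""

    def __init__(self, owner: str) -> None:
        self.owner = owner


def gather_states(
    fqns: List[str],
    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
) -> Dict[str, FSDPParamInfo]:
    gathered: Dict[str, FSDPParamInfo] = {}
    for fqn in fqns:
        # Before the patch, an indexed lookup (fqn_to_fsdp_param_info[fqn])
        # would raise KeyError for a parameter this FSDP instance does not
        # manage. Using .get() lets us detect and skip that case instead.
        info: Optional[FSDPParamInfo] = fqn_to_fsdp_param_info.get(fqn, None)
        if info is None:
            # Skip parameters owned by a different FSDP instance, as can
            # happen with MPMD-style parallelism.
            continue
        gathered[fqn] = info
    return gathered


# Suppose this rank's FSDP instance only manages "layer0.weight", while
# "layer1.weight" belongs to an FSDP instance in another program.
local_info = {"layer0.weight": FSDPParamInfo(owner="rank0")}
result = gather_states(["layer0.weight", "layer1.weight"], local_info)
assert list(result) == ["layer0.weight"]  # the foreign FQN was skipped
```

The design choice mirrors the patch: tolerate missing FQNs by skipping rather than raising, since under MPMD-style parallelism the absence of a parameter in the current instance's mapping is expected rather than an error.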