
Commit 76a484d

LucasLLC authored and jeffdaily committed
[DCP] Removes Checkpoint Wrapped Prefix from state dict fqns (pytorch#118119)
Fixes pytorch#117399

~~Soliciting some early feedback here.~~
~~Do we happen to know if there are already some tests that cover this case, or would it make sense to add them? @fegin, @wz337~~

Edit: Added tests

Pull Request resolved: pytorch#118119
Approved by: https://github.com/fegin
1 parent 9afad97 commit 76a484d
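
For context, here is an illustrative sketch (not part of the commit) of where the prefix comes from: `apply_activation_checkpointing` inserts a wrapper module, so the wrapper name shows up in the names yielded by `named_parameters()`, which is what `_get_fqns` walks. The toy model and the exact key strings below are assumptions for illustration.

```python
# Illustrative sketch only; the toy model and exact key strings are assumptions.
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
apply_activation_checkpointing(model)

# Expect names like "0._checkpoint_wrapped_module.weight" instead of "0.weight";
# before this commit, such wrapped names leaked into the FQNs DCP computed.
print(sorted(name for name, _ in model.named_parameters()))
```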

2 files changed: +23 −2 lines changed

test/distributed/checkpoint/test_state_dict.py

Lines changed: 16 additions & 0 deletions

@@ -11,6 +11,9 @@
 from torch.distributed._composable import fully_shard, replicate
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor, init_device_mesh
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    apply_activation_checkpointing,
+)
 from torch.distributed.checkpoint.state_dict import (
     _patch_model_state_dict,
     _patch_optimizer_state_dict,
@@ -443,6 +446,19 @@ def is_cpu(v):
         self.assertEqual(mst, {})
         self.assertEqual(ost, {})
 
+    @with_comms
+    @skip_if_lt_x_gpu(1)
+    def test_activation_ckpt_fqns(self) -> None:
+        """Tests that activation checkpointing prefixes are removed from module names"""
+        model = CompositeParamModel(device=torch.device("cuda"))
+        original_keys = get_model_state_dict(model).keys()
+
+        apply_activation_checkpointing(model)
+        model = DDP(model)
+        new_keys = get_model_state_dict(model).keys()
+
+        self.assertEqual(original_keys, new_keys)
+
 
 if __name__ == "__main__":
     run_tests()
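
In short, the new test asserts that applying activation checkpointing and then wrapping the model in DDP leaves the keys returned by `get_model_state_dict` identical to those of the original, unwrapped model.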

torch/distributed/checkpoint/state_dict.py

Lines changed: 7 additions & 2 deletions

@@ -26,6 +26,9 @@
     _offload_state_dict_to_cpu,
 )
 from torch.distributed._tensor import DTensor
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    _CHECKPOINT_PREFIX,
+)
 from torch.distributed.fsdp import (
     FullOptimStateDictConfig,
     FullStateDictConfig,
@@ -145,7 +148,7 @@ def _get_fqns(model: nn.Module, name: str, skip_ddp_prefix: bool = True) -> FQNS
         The canonical FQNs based on the model traversal.
     """
     if "." not in name:
-        return {name}
+        return {name.replace(_CHECKPOINT_PREFIX, "")}
 
     obj_names = name.split(".")
     fqn_obj_names = []
@@ -162,6 +165,8 @@ def _get_fqns(model: nn.Module, name: str, skip_ddp_prefix: bool = True) -> FQNS
                 flat_param = getattr(curr_obj, FLAT_PARAM)
                 if prefix:
                     prefix = f"{prefix}."
+                # FSDP already handles removal of checkpoint prefix, so we can return
+                # directly
                 return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
             curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
             if curr_obj_name != FSDP_WRAPPED_MODULE:
@@ -171,7 +176,7 @@ def _get_fqns(model: nn.Module, name: str, skip_ddp_prefix: bool = True) -> FQNS
             fqn_obj_names.append(curr_obj_name)
             curr_obj = getattr(curr_obj, curr_obj_name)
 
-    return {".".join(fqn_obj_names)}
+    return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")}
 
 
 def _verify_options(
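
For reference, a minimal sketch of what the two changed `return` statements do. The helper name is hypothetical, and the prefix value shown is an assumption based on `_CHECKPOINT_PREFIX` from `checkpoint_wrapper`.

```python
# Hypothetical helper mirroring the two changed return statements above.
# The prefix value is an assumption based on checkpoint_wrapper._CHECKPOINT_PREFIX.
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."

def strip_checkpoint_prefix(fqn: str) -> str:
    # e.g. "l1._checkpoint_wrapped_module.weight" -> "l1.weight"
    return fqn.replace(_CHECKPOINT_PREFIX, "")

assert strip_checkpoint_prefix("l1._checkpoint_wrapped_module.weight") == "l1.weight"
assert strip_checkpoint_prefix("l1.weight") == "l1.weight"  # no-op when no wrapper is present
```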
