From 5531b9c46226b686992b5386896d8ad05b4cfd3a Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Mon, 6 Nov 2023 14:15:19 -0800
Subject: [PATCH] [Dist] Enable FSDP on CPU (#112145)

Differential Revision: [D50688958](https://our.internmc.facebook.com/intern/diff/D50688958/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112145
Approved by: https://github.com/fegin
ghstack dependencies: #112144
---
 test/distributed/fsdp/test_fsdp_misc.py | 23 +++++++++++++++++++++++
 torch/cpu/__init__.py                   | 21 ++++++++++++++++++++-
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py
index cdf11174c579b..b409ec5bb3ff1 100644
--- a/test/distributed/fsdp/test_fsdp_misc.py
+++ b/test/distributed/fsdp/test_fsdp_misc.py
@@ -502,6 +502,29 @@ def test_fsdp_optimizer_overlap(self):
             (n, p.clone()) for n, p in fsdp_overlap.named_parameters()
         ]
 
+    @skip_if_lt_x_gpu(2)
+    def test_fsdp_cpu_training(self):
+        """Tests FSDP training on CPU."""
+        torch.manual_seed(0)
+        gloo_pg = dist.new_group(backend="gloo")
+        for ss in [
+            ShardingStrategy.NO_SHARD,
+            ShardingStrategy.FULL_SHARD,
+            ShardingStrategy.SHARD_GRAD_OP,
+            ShardingStrategy.HYBRID_SHARD,
+            ShardingStrategy._HYBRID_SHARD_ZERO2,
+        ]:
+            model = MyModel()
+            fsdp = FSDP(
+                model,
+                auto_wrap_policy=always_wrap_policy,
+                process_group=gloo_pg,
+                device_id=torch.device("cpu"),
+                sharding_strategy=ss,
+            )
+            inp = torch.randn(2, 2)
+            fsdp(inp, inp).sum().backward()
+
     @skip_if_lt_x_gpu(2)
     def test_fsdp_cpu_init_stays_on_cpu(self):
         # Move me to MT test once warning logging and backward collective issue
diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py
index 33a9c659e12eb..14794627d752b 100644
--- a/torch/cpu/__init__.py
+++ b/torch/cpu/__init__.py
@@ -21,6 +21,7 @@
     "device_count",
     "Stream",
     "StreamContext",
+    "Event",
 ]
 
 _device_t = Union[_device, str, int, None]
@@ -56,7 +57,25 @@ class Stream:
     N.B. This class only exists to facilitate device-agnostic code
     """
 
-    pass
+    def __init__(self, priority: int = -1):
+        pass
+
+    def wait_stream(self, stream) -> None:
+        pass
+
+
+class Event:
+    def query(self) -> bool:
+        return True
+
+    def record(self, stream=None):
+        pass
+
+    def synchronize(self):
+        pass
+
+    def wait(self, stream=None):
+        pass
 
 
 _default_cpu_stream = Stream()
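
For context, a minimal sketch of the CPU training flow this patch enables: FSDP wraps a module with a gloo process group and `device_id=torch.device("cpu")`, while the no-op `torch.cpu.Stream`/`torch.cpu.Event` stubs let FSDP's device-agnostic stream handling run on CPU. The toy `Net` module and the single-rank launch below are illustrative assumptions, not part of the patch:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import always_wrap_policy


class Net(nn.Module):
    """Hypothetical toy model; any nn.Module works here."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 2)

    def forward(self, x):
        return self.lin(x)


if __name__ == "__main__":
    # Single-rank gloo setup for demonstration; a real run would launch
    # multiple ranks with torchrun, which sets these env vars itself.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # device_id=torch.device("cpu") keeps parameter sharding and the
    # backing collectives on CPU, which this patch makes possible.
    model = FSDP(
        Net(),
        auto_wrap_policy=always_wrap_policy,
        device_id=torch.device("cpu"),
    )
    inp = torch.randn(2, 2)
    model(inp).sum().backward()  # forward + backward entirely on CPU

    dist.destroy_process_group()
```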