stash: fsdp method
Secbone committed May 13, 2024
1 parent 4c5aa0f · commit 982648b
Showing 1 changed file with 7 additions and 4 deletions.
toad/nn/distributed/accelerate/accelerator.py (11 changes: 7 additions & 4 deletions)
@@ -43,6 +43,9 @@ def strategy(self):
     @property
     def device(self):
         import torch
+        if self.strategy.device.type == 'cpu':
+            return torch.device('cpu')
+
         return torch.device(f"cuda:{self.rank}")
 
     def setup(self):
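
The device property now resolves to the CPU whenever the strategy targets it, and otherwise pins each rank to its own GPU. A minimal standalone sketch of that selection logic (the function name and arguments are illustrative, not part of the commit):

    import torch

    def resolve_device(strategy_device: torch.device, rank: int) -> torch.device:
        # CPU strategy: every rank shares the single CPU device.
        if strategy_device.type == 'cpu':
            return torch.device('cpu')
        # CUDA strategy: map each rank to its own GPU, e.g. rank 2 -> cuda:2.
        return torch.device(f"cuda:{rank}")

    # resolve_device(torch.device('cpu'), 0)  -> device(type='cpu')
    # resolve_device(torch.device('cuda'), 2) -> device(type='cuda', index=2)
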
@@ -61,7 +64,8 @@ def setup(self):
             init_method = master_url,
         )
 
-        torch.cuda.set_device(self.device)
+        if self.device.type == 'cuda':
+            torch.cuda.set_device(self.device)
 
 
     def prepare(self, module, loader, optimizer):
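
Guarding torch.cuda.set_device behind a device-type check keeps setup usable on CPU-only hosts, where CUDA calls would fail. A hedged sketch of the full initialization sequence, assuming a tcp:// master_url and the conventional backend split (nccl for CUDA, gloo for CPU; the diff itself does not show which backend the repository selects):

    import torch
    import torch.distributed as dist

    def setup_process_group(device: torch.device, rank: int, world_size: int,
                            master_url: str = "tcp://127.0.0.1:29500") -> None:
        # nccl requires CUDA; gloo covers the CPU path.
        backend = 'nccl' if device.type == 'cuda' else 'gloo'
        dist.init_process_group(
            backend = backend,
            init_method = master_url,
            rank = rank,
            world_size = world_size,
        )
        # Only bind a CUDA device when one is actually in use.
        if device.type == 'cuda':
            torch.cuda.set_device(device)
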
@@ -78,19 +82,18 @@ def prepare_module(self, module):
 
         if isinstance(self.strategy, FSDPStrategy):
             from ..fsdp import FSDP
+            from torch.distributed.fsdp import CPUOffload
 
-            module.to(self.strategy.device)
             module = FSDP(
                 module,
                 auto_wrap_policy = self.strategy.policy,
                 device_id = self.device,
+                cpu_offload = CPUOffload(offload_params = True) if self.device.type == 'cuda' else None,
             )
 
         elif isinstance(self.strategy, DDPStrategy):
             from ..ddp import DDP
             module = DDP(module)
 
-        module.to(self.device)
-
         return module
         # return ModuleMixin.mixin(module)
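
The FSDP branch puts CPU offloading behind the same device-type check: CPUOffload(offload_params=True) is passed only on CUDA, since offloading parameters to host memory presupposes GPU-resident compute. For reference, a sketch of the same wrapping pattern written directly against the upstream torch.distributed.fsdp API, assuming that is what the local ..fsdp.FSDP wrapper delegates to (an illustration, not the repository's implementation):

    import functools
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    def wrap_fsdp(module: torch.nn.Module, device: torch.device) -> FSDP:
        # Assumes torch.distributed is already initialized (see setup above).
        # Wrap any submodule holding >= 1e6 parameters in its own FSDP unit.
        policy = functools.partial(size_based_auto_wrap_policy, min_num_params = 1_000_000)
        return FSDP(
            module,
            auto_wrap_policy = policy,
            device_id = device if device.type == 'cuda' else None,
            # Offload sharded parameters to CPU between uses; only meaningful
            # when compute runs on a GPU, hence None on the CPU path.
            cpu_offload = CPUOffload(offload_params = True) if device.type == 'cuda' else None,
        )
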
