Skip to content

Commit

Permalink
add: fsdp module
Browse files Browse the repository at this point in the history
  • Loading branch information
Secbone committed Apr 10, 2024
1 parent 4952f48 commit 92e94a3
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions toad/nn/distributed/fsdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)


class FSDPModule(FSDP):
    """Distributed module wrapper built on FullyShardedDataParallel.

    FSDP itself only forwards ``forward``; this subclass additionally
    forwards the project's training/persistence/logging helpers
    (``fit``, ``save``, ``load``, ``log``) to the wrapped inner module,
    so a sharded model can be driven with the same API as a plain one.
    """

    def fit(self, *args, **kwargs):
        """Run training by delegating to the wrapped module's ``fit``."""
        wrapped = self.module
        return wrapped.fit(*args, **kwargs)

    def save(self, *args, **kwargs):
        """Persist the model by delegating to the wrapped module's ``save``."""
        wrapped = self.module
        return wrapped.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        """Restore the model by delegating to the wrapped module's ``load``."""
        wrapped = self.module
        return wrapped.load(*args, **kwargs)

    def log(self, *args, **kwargs):
        """Emit logging by delegating to the wrapped module's ``log``."""
        wrapped = self.module
        return wrapped.log(*args, **kwargs)

0 comments on commit 92e94a3

Please sign in to comment.