Merge pull request #290 from allenai/torch2.1init
fix for torch 2.1 inits
saurabh111233212 committed Sep 26, 2023
2 parents aec449c + 7f5e4e5 commit 012e97f
Showing 1 changed file with 15 additions and 0 deletions: scripts/train.py
@@ -11,6 +11,7 @@
 import torch
 import torch.distributed as dist
 import wandb
+from packaging import version
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

@@ -111,6 +112,16 @@ def main(cfg: TrainConfig) -> None:
         wrap_policy = olmo_model.fsdp_wrap_fn
     elif cfg.fsdp.wrapping_strategy == FSDPWrapStrategy.size_based:
         wrap_policy = size_based_auto_wrap_policy
+
+    if version.parse(torch.__version__) >= version.parse("2.1.0"):
+        # This prevents any parameters from being initialized twice
+        def dummy_init_fn(module: torch.nn.Module) -> None:
+            module.to_empty(device=get_local_rank())
+
+        param_init_fn = dummy_init_fn
+    else:
+        param_init_fn = None
+
     fsdp_model = FSDP(
         olmo_model,
         sharding_strategy=cfg.fsdp.sharding_strategy,
@@ -119,7 +130,11 @@ def main(cfg: TrainConfig) -> None:
         use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
         limit_all_gathers=True,
         device_id=get_local_rank(),
+        param_init_fn=param_init_fn,
     )
+    # when param_init_fn is None, FSDP will call reset_parameters() automatically
+    if param_init_fn is not None:
+        olmo_model.reset_parameters()
 
     log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")

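For context, the pattern this diff applies can be sketched in isolation. The following is a minimal, hypothetical example, not OLMo code: ToyModel, wrap_with_fsdp, and local_rank are stand-ins introduced here, and it assumes torch.distributed has already been initialized (e.g. via torchrun) and that a CUDA device is available. The idea mirrors the commit: on torch >= 2.1, pass FSDP a param_init_fn that only materializes storage with to_empty(), so parameters are not initialized twice, then call the model's own reset_parameters() once after wrapping.

import torch
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for the OLMo model; only reset_parameters() matters here."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def reset_parameters(self) -> None:
        # The model's own initialization; with the dummy param_init_fn below,
        # this is the only place the weights are initialized.
        torch.nn.init.normal_(self.linear.weight, std=0.02)
        torch.nn.init.zeros_(self.linear.bias)


def wrap_with_fsdp(model: ToyModel, local_rank: int) -> FSDP:
    if version.parse(torch.__version__) >= version.parse("2.1.0"):
        # On torch >= 2.1, give FSDP an init fn that only materializes storage,
        # so parameters are not initialized a second time during wrapping.
        def dummy_init_fn(module: torch.nn.Module) -> None:
            module.to_empty(device=local_rank)

        param_init_fn = dummy_init_fn
    else:
        param_init_fn = None

    fsdp_model = FSDP(
        model,
        device_id=local_rank,
        use_orig_params=True,  # mirrors the script's FSDP call
        param_init_fn=param_init_fn,
    )
    # Per the comment in the commit: when param_init_fn is None, FSDP calls
    # reset_parameters() automatically, so only the torch >= 2.1 path needs
    # the explicit call after wrapping.
    if param_init_fn is not None:
        model.reset_parameters()
    return fsdp_model


# Usage sketch (inside a distributed run, e.g. launched with torchrun):
#   torch.distributed.init_process_group(backend="nccl")
#   local_rank = int(os.environ["LOCAL_RANK"])
#   torch.cuda.set_device(local_rank)
#   fsdp_model = wrap_with_fsdp(ToyModel(), local_rank)

Keeping use_orig_params=True in the sketch follows the script's own FSDP call and keeps the original parameter handles usable when reset_parameters() runs after wrapping.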
