From 6e5f6674b7a6f8d13703b0421c8193c3df2966cc Mon Sep 17 00:00:00 2001
From: Saurabh Shah
Date: Tue, 26 Sep 2023 10:25:41 -0700
Subject: [PATCH 1/7] one-line fix for torch 2.1 inits

---
 scripts/train.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/scripts/train.py b/scripts/train.py
index 311f7b9d0..54f49a060 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -120,6 +120,8 @@ def main(cfg: TrainConfig) -> None:
         limit_all_gathers=True,
         device_id=get_local_rank(),
     )
+    # necessary for torch2.1.0 and beyond! Leads to double-init on earlier versions, but that's okay.
+    olmo_model.reset_parameters()
 
     log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")
 

From 4f00ae3110c0a8bda088fc23f536039630369976 Mon Sep 17 00:00:00 2001
From: Saurabh Shah
Date: Tue, 26 Sep 2023 10:56:56 -0700
Subject: [PATCH 2/7] ensure nothing changes with inits for earlier versions

---
 scripts/train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/scripts/train.py b/scripts/train.py
index 54f49a060..e90bd380c 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -120,8 +120,9 @@ def main(cfg: TrainConfig) -> None:
         limit_all_gathers=True,
         device_id=get_local_rank(),
     )
-    # necessary for torch2.1.0 and beyond! Leads to double-init on earlier versions, but that's okay.
-    olmo_model.reset_parameters()
+    # necessary for torch 2.1.0 and beyond!
+    if torch.__version__ >= "2.1.0":
+        olmo_model.reset_parameters()
 
     log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")
 

From be590647b30fdf96f8da0795c4a6d8163e8efb9a Mon Sep 17 00:00:00 2001
From: Saurabh Shah
Date: Tue, 26 Sep 2023 13:34:00 -0700
Subject: [PATCH 3/7] fixes from Pete's comments

---
 scripts/train.py | 36 +++++++++++++++++++++++++-----------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/scripts/train.py b/scripts/train.py
index e90bd380c..d06e21582 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -7,6 +7,7 @@
 from functools import partial
 from pathlib import Path
 from typing import Optional, TextIO
+from packaging import version
 
 import torch
 import torch.distributed as dist
@@ -111,18 +112,31 @@ def main(cfg: TrainConfig) -> None:
         wrap_policy = olmo_model.fsdp_wrap_fn
     elif cfg.fsdp.wrapping_strategy == FSDPWrapStrategy.size_based:
         wrap_policy = size_based_auto_wrap_policy
-    fsdp_model = FSDP(
-        olmo_model,
-        sharding_strategy=cfg.fsdp.sharding_strategy,
-        mixed_precision=cfg.fsdp_precision,
-        auto_wrap_policy=wrap_policy,
-        use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
-        limit_all_gathers=True,
-        device_id=get_local_rank(),
-    )
-    # necessary for torch 2.1.0 and beyond!
-    if torch.__version__ >= "2.1.0":
+    if version.parse(torch.__version__) >= version.parse("2.1.0"):
+        # This prevents any parameters from being initialized twice
+        def dummy_init_fn(module: torch.nn.Module) -> None:
+            module.to_empty(device=get_local_rank())
+        fsdp_model = FSDP(
+            olmo_model,
+            sharding_strategy=cfg.fsdp.sharding_strategy,
+            mixed_precision=cfg.fsdp_precision,
+            auto_wrap_policy=wrap_policy,
+            use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
+            limit_all_gathers=True,
+            device_id=get_local_rank(),
+            param_init_fn=dummy_init_fn
+        )
         olmo_model.reset_parameters()
+    else:
+        fsdp_model = FSDP(
+            olmo_model,
+            sharding_strategy=cfg.fsdp.sharding_strategy,
+            mixed_precision=cfg.fsdp_precision,
+            auto_wrap_policy=wrap_policy,
+            use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
+            limit_all_gathers=True,
+            device_id=get_local_rank()
+        )
 
     log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")
 

From 56228a1b65e361f51d6fa0a734d2723f08ca3834 Mon Sep 17 00:00:00 2001
From: Saurabh Shah
Date: Tue, 26 Sep 2023 13:41:55 -0700
Subject: [PATCH 4/7] formatting

---
 scripts/train.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/scripts/train.py b/scripts/train.py
index d06e21582..813871f79 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -5,9 +5,9 @@
 import os
 import sys
 from functools import partial
+from packaging import version
 from pathlib import Path
 from typing import Optional, TextIO
-from packaging import version
 
 import torch
 import torch.distributed as dist
@@ -116,6 +116,7 @@ def main(cfg: TrainConfig) -> None:
         # This prevents any parameters from being initialized twice
         def dummy_init_fn(module: torch.nn.Module) -> None:
             module.to_empty(device=get_local_rank())
+
         fsdp_model = FSDP(
             olmo_model,
             sharding_strategy=cfg.fsdp.sharding_strategy,
@@ -124,7 +125,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
             use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
             limit_all_gathers=True,
             device_id=get_local_rank(),
-            param_init_fn=dummy_init_fn
+            param_init_fn=dummy_init_fn,
         )
         olmo_model.reset_parameters()
     else:
@@ -135,7 +136,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
             auto_wrap_policy=wrap_policy,
             use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
             limit_all_gathers=True,
-            device_id=get_local_rank()
+            device_id=get_local_rank(),
         )
 
     log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")

From e56056672ceb8b8dfd1ead26c7c1b6787f9412c2 Mon Sep 17 00:00:00 2001
From: Saurabh Shah
Date: Tue, 26 Sep 2023 13:49:48 -0700
Subject: [PATCH 5/7] fixed import

---
 scripts/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/train.py b/scripts/train.py
index 813871f79..81a5510ce 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -5,13 +5,13 @@
 import os
 import sys
 from functools import partial
-from packaging import version
 from pathlib import Path
 from typing import Optional, TextIO
 
 import torch
 import torch.distributed as dist
 import wandb
+from packaging import version
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
 

From a5a88fc7281b9c88265bfad27ad0af05de146766 Mon Sep 17 00:00:00 2001
From: Saurabh Shah
Date: Tue, 26 Sep 2023 14:12:01 -0700
Subject: [PATCH 6/7] not double calling FSDP

---
 scripts/train.py | 36 ++++++++++++++++--------------------
 1 file changed, 16 insertions(+), 20 deletions(-)

diff --git a/scripts/train.py b/scripts/train.py
index 81a5510ce..7e9fd9687 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -112,32 +112,28 @@ def main(cfg: TrainConfig) -> None:
         wrap_policy = olmo_model.fsdp_wrap_fn
     elif cfg.fsdp.wrapping_strategy == FSDPWrapStrategy.size_based:
         wrap_policy = size_based_auto_wrap_policy
+
     if version.parse(torch.__version__) >= version.parse("2.1.0"):
         # This prevents any parameters from being initialized twice
         def dummy_init_fn(module: torch.nn.Module) -> None:
             module.to_empty(device=get_local_rank())
 
-        fsdp_model = FSDP(
-            olmo_model,
-            sharding_strategy=cfg.fsdp.sharding_strategy,
-            mixed_precision=cfg.fsdp_precision,
-            auto_wrap_policy=wrap_policy,
-            use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
-            limit_all_gathers=True,
-            device_id=get_local_rank(),
-            param_init_fn=dummy_init_fn,
-        )
-        olmo_model.reset_parameters()
+        param_init_fn = dummy_init_fn
     else:
-        fsdp_model = FSDP(
-            olmo_model,
-            sharding_strategy=cfg.fsdp.sharding_strategy,
-            mixed_precision=cfg.fsdp_precision,
-            auto_wrap_policy=wrap_policy,
-            use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
-            limit_all_gathers=True,
-            device_id=get_local_rank(),
-        )
+        param_init_fn = None
+
+    fsdp_model = FSDP(
+        olmo_model,
+        sharding_strategy=cfg.fsdp.sharding_strategy,
+        mixed_precision=cfg.fsdp_precision,
+        auto_wrap_policy=wrap_policy,
+        use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
+        limit_all_gathers=True,
+        device_id=get_local_rank(),
+        param_init_fn=param_init_fn,
+    )
+    if param_init_fn is not None:
+        olmo_model.reset_parameters()
 
     log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")
 

From 7f5e4e550fb1d31b9c1ffce65f1a6614d490075f Mon Sep 17 00:00:00 2001
From: Saurabh Shah
Date: Tue, 26 Sep 2023 14:38:46 -0700
Subject: [PATCH 7/7] added comment

---
 scripts/train.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/scripts/train.py b/scripts/train.py
index 7e9fd9687..fc5332b1f 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -132,6 +132,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         device_id=get_local_rank(),
         param_init_fn=param_init_fn,
     )
+    # when param_init_fn is None, FSDP will call reset_parameters() automatically
    if param_init_fn is not None:
         olmo_model.reset_parameters()
 
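
For reference, here is the net effect of the seven patches above collapsed into a single non-diff sketch. It is only illustrative: olmo_model, cfg, and wrap_policy stand in for objects built earlier in scripts/train.py, local_rank stands in for get_local_rank(), and the wrapper function itself is not code from the repository.

# Reference-only sketch of the final state of the FSDP wrapping logic after PATCH 7/7.
from packaging import version
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def build_fsdp_model(olmo_model, cfg, wrap_policy, local_rank: int):
    # packaging.version compares release components numerically, which a plain
    # string comparison like torch.__version__ >= "2.1.0" would not (e.g. "2.10.0").
    if version.parse(torch.__version__) >= version.parse("2.1.0"):
        # This prevents any parameters from being initialized twice.
        def dummy_init_fn(module: torch.nn.Module) -> None:
            module.to_empty(device=local_rank)

        param_init_fn = dummy_init_fn
    else:
        param_init_fn = None

    fsdp_model = FSDP(
        olmo_model,
        sharding_strategy=cfg.fsdp.sharding_strategy,
        mixed_precision=cfg.fsdp_precision,
        auto_wrap_policy=wrap_policy,
        use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some optimizer/parameter metrics
        limit_all_gathers=True,
        device_id=local_rank,
        param_init_fn=param_init_fn,
    )
    # when param_init_fn is None, FSDP will call reset_parameters() automatically
    if param_init_fn is not None:
        olmo_model.reset_parameters()
    return fsdp_model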