Fix speed issue on LUMI with 7B model (#270)
epwalsh committed Sep 13, 2023
1 parent d2abecd commit 2eedf07
Showing 2 changed files with 26 additions and 2 deletions.
18 changes: 18 additions & 0 deletions olmo/config.py
@@ -501,6 +501,18 @@ class CompilerConfig(BaseConfig):
"""


class FSDPWrapStrategy(StrEnum):
by_block = "by_block"
"""
Wrap each OLMo block with its own FSDP instance.
"""

size_based = "size_based"
"""
Use PyTorch's default size-based auto wrap policy.
"""


@dataclass
class FSDPConfig(BaseConfig):
use_orig_params: bool = True
@@ -510,6 +522,12 @@ class FSDPConfig(BaseConfig):

sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD

wrapping_strategy: Optional[FSDPWrapStrategy] = None
"""
The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level
FSDP instance.
"""


class CheckpointType(StrEnum):
sharded = "sharded"
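For illustration, here is a minimal sketch (not part of the commit) of selecting the new field when constructing the config in Python; the class and field names come from olmo/config.py above, while everything around them is assumed:

from olmo.config import FSDPConfig, FSDPWrapStrategy

# Wrap every OLMo block in its own FSDP instance; FSDPWrapStrategy.size_based
# would select PyTorch's size-based auto wrap policy instead.
fsdp_cfg = FSDPConfig(wrapping_strategy=FSDPWrapStrategy.by_block)

# Leaving wrapping_strategy as None (the default) keeps the previous behavior:
# the whole model is wrapped in a single top-level FSDP instance.
assert FSDPConfig().wrapping_strategy is None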
10 changes: 8 additions & 2 deletions scripts/train.py
@@ -13,9 +13,10 @@
import wandb
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torchmetrics import MeanMetric

from olmo.config import CheckpointType, TrainConfig
from olmo.config import CheckpointType, FSDPWrapStrategy, TrainConfig
from olmo.data import build_train_dataloader
from olmo.eval import build_evaluators
from olmo.exceptions import OlmoCliError, OlmoConfigurationError
@@ -107,6 +108,11 @@ def main(cfg: TrainConfig) -> None:

# Wrap the model in FSDP.
log.info("Wrapping model with FDSP...")
wrap_policy = None
if cfg.fsdp.wrapping_strategy == FSDPWrapStrategy.by_block:
wrap_policy = olmo_model.fsdp_wrap_fn
elif cfg.fsdp.wrapping_strategy == FSDPWrapStrategy.size_based:
wrap_policy = size_based_auto_wrap_policy
fsdp_model = FSDP(
olmo_model,
sharding_strategy=cfg.fsdp.sharding_strategy,
@@ -115,7 +121,7 @@
reduce_dtype=cfg.autocast_precision,
buffer_dtype=cfg.autocast_precision,
),
auto_wrap_policy=olmo_model.fsdp_wrap_fn,
auto_wrap_policy=wrap_policy,
use_orig_params=cfg.fsdp.use_orig_params, # needed for compile and some of our optimizer/parameter metrics
limit_all_gathers=True,
device_id=get_local_rank(),
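Note that size_based_auto_wrap_policy is passed to FSDP directly above, so PyTorch's default parameter-count threshold applies. A short sketch (not from this commit; the 20M threshold is illustrative) of how that threshold could be tuned with functools.partial:

import functools

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Each submodule whose not-yet-wrapped parameter count exceeds min_num_params gets
# its own FSDP instance; 20M is an illustrative value, not one used in this commit.
tuned_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=20_000_000)

# fsdp_model = FSDP(olmo_model, auto_wrap_policy=tuned_policy, ...)
# (Requires an initialized torch.distributed process group, as set up earlier in scripts/train.py.)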
