Commit

Merge pull request #280 from allenai/petew/reduce-dtype
add support for reducing gradients in fp32
dirkgr committed Sep 21, 2023
2 parents 2a7f694 + 8cefe6a commit 70a3f4c
Showing 2 changed files with 35 additions and 7 deletions.
35 changes: 34 additions & 1 deletion olmo/config.py
@@ -20,7 +20,7 @@
 import torch
 from omegaconf import OmegaConf as om
 from omegaconf.errors import OmegaConfBaseException
-from torch.distributed.fsdp import ShardingStrategy
+from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 
 from .aliases import PathOrStr
 from .exceptions import OlmoConfigurationError
@@ -513,6 +513,20 @@ class FSDPWrapStrategy(StrEnum):
     """
 
 
+class FSDPPrecision(StrEnum):
+    pure = "pure"
+    """
+    Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``,
+    and ``buffer_dtype`` all set to the autocast precision data type.
+    """
+
+    mixed = "mixed"
+    """
+    Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype`` and ``buffer_dtype``
+    set to the autocast precision data type, while ``reduce_dtype`` is set to fp32.
+    """
+
+
 @dataclass
 class FSDPConfig(BaseConfig):
     use_orig_params: bool = True
@@ -528,6 +542,8 @@ class FSDPConfig(BaseConfig):
     FSDP instance.
     """
 
+    precision: FSDPPrecision = FSDPPrecision.pure
+
 
 class CheckpointType(StrEnum):
     sharded = "sharded"
@@ -784,3 +800,20 @@ def autocast_precision(self) -> torch.dtype:
             return torch.float32
         else:
             raise ValueError(f"Unexpected precision type '{self.precision}'")
+
+    @property
+    def fsdp_precision(self) -> MixedPrecision:
+        if self.fsdp.precision == FSDPPrecision.pure:
+            return MixedPrecision(
+                param_dtype=self.autocast_precision,
+                reduce_dtype=self.autocast_precision,
+                buffer_dtype=self.autocast_precision,
+            )
+        elif self.fsdp.precision == FSDPPrecision.mixed:
+            return MixedPrecision(
+                param_dtype=self.autocast_precision,
+                reduce_dtype=torch.float32,
+                buffer_dtype=self.autocast_precision,
+            )
+        else:
+            raise NotImplementedError(f"{self.fsdp.precision}")
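
For readers skimming the diff, a minimal standalone sketch of what each setting resolves to. The helper function below is illustrative only (the repo exposes the same mapping through the new TrainConfig.fsdp_precision property shown above); it runs without any distributed setup:

import torch
from torch.distributed.fsdp import MixedPrecision


def resolve_fsdp_precision(precision: str, autocast_dtype: torch.dtype) -> MixedPrecision:
    # Hypothetical helper mirroring the new TrainConfig.fsdp_precision property.
    if precision == "pure":
        # "pure": params, buffers, and gradient reduction all use the autocast dtype.
        return MixedPrecision(
            param_dtype=autocast_dtype,
            reduce_dtype=autocast_dtype,
            buffer_dtype=autocast_dtype,
        )
    elif precision == "mixed":
        # "mixed": params/buffers stay in the autocast dtype, but gradients are
        # all-reduced in fp32 for more stable accumulation across ranks.
        return MixedPrecision(
            param_dtype=autocast_dtype,
            reduce_dtype=torch.float32,
            buffer_dtype=autocast_dtype,
        )
    raise NotImplementedError(precision)


# With bf16 autocast, the "mixed" policy keeps fp32 only for the gradient reduction:
assert resolve_fsdp_precision("mixed", torch.bfloat16).reduce_dtype == torch.float32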
7 changes: 1 addition & 6 deletions scripts/train.py
@@ -12,7 +12,6 @@
 import torch.distributed as dist
 import wandb
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import MixedPrecision
 from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
 from torchmetrics import MeanMetric
 
@@ -116,11 +115,7 @@ def main(cfg: TrainConfig) -> None:
     fsdp_model = FSDP(
         olmo_model,
         sharding_strategy=cfg.fsdp.sharding_strategy,
-        mixed_precision=MixedPrecision(  # equivalent to MosaicML's "PURE"
-            param_dtype=cfg.autocast_precision,
-            reduce_dtype=cfg.autocast_precision,
-            buffer_dtype=cfg.autocast_precision,
-        ),
+        mixed_precision=cfg.fsdp_precision,
         auto_wrap_policy=wrap_policy,
         use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
         limit_all_gathers=True,
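
As a usage note, a self-contained sketch of the wrapping that scripts/train.py now performs when the config's fsdp.precision is set to "mixed". The toy model and single-process setup below are assumptions for illustration, not code from the repo:

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision


def main() -> None:
    # Single-process "distributed" setup so the example runs on one GPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
    ).cuda()

    # Equivalent of FSDPPrecision.mixed with bf16 autocast: compute and
    # communication stay in bf16, but gradients are all-reduced in fp32.
    fsdp_model = FSDP(
        model,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.bfloat16,
        ),
        limit_all_gathers=True,
    )

    loss = fsdp_model(torch.randn(8, 1024, device="cuda")).pow(2).mean()
    loss.backward()  # the gradient reduction triggered here runs in fp32
    dist.destroy_process_group()


if __name__ == "__main__":
    main()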
