Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions src/opentau/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,35 @@ def update_policy(
return train_metrics


def _sync_deepspeed_gradient_accumulation_steps(
    accelerator: accelerate.Accelerator, cfg: TrainPipelineConfig
) -> None:
    """Push `cfg.gradient_accumulation_steps` into the DeepSpeed plugin configuration.

    TrainPipelineConfig is treated as the single source of truth: whatever the
    Accelerate YAML declared under `deepspeed_config.gradient_accumulation_steps`
    is replaced by the value carried on `cfg`.

    The function returns immediately when the distributed backend is not DeepSpeed.
    Otherwise it writes the target value both into the `hf_ds_config.config` dict
    and onto the plugin's `gradient_accumulation_steps` attribute. If the previous
    dict value (taken as 1 when the key is absent) differs from the target, a
    warning is logged, but only on the main process.

    Must be called on all ranks and before `accelerator.prepare(...)`, because
    `_prepare_deepspeed` reads this value from `hf_ds_config.config` on every rank
    at prepare time.

    Args:
        accelerator: The accelerator whose DeepSpeed plugin should be updated.
        cfg: Training configuration providing `gradient_accumulation_steps`.
    """
    if accelerator.distributed_type != accelerate.DistributedType.DEEPSPEED:
        return

    plugin = accelerator.deepspeed_plugin
    raw_ds_config = plugin.hf_ds_config.config
    previous_steps = raw_ds_config.get("gradient_accumulation_steps", 1)
    desired_steps = cfg.gradient_accumulation_steps

    if previous_steps != desired_steps:
        # Only the main process reports the override; every rank still applies it below.
        if accelerator.is_main_process:
            logging.warning(
                "Overriding DeepSpeed `gradient_accumulation_steps` (%s) with the value from "
                "TrainPipelineConfig (%s). TrainPipelineConfig is the single source of truth; "
                "the value in the Accelerate YAML is ignored.",
                previous_steps,
                desired_steps,
            )

    # Written unconditionally so the dict entry and the plugin attribute always agree.
    raw_ds_config["gradient_accumulation_steps"] = desired_steps
    plugin.gradient_accumulation_steps = desired_steps


@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
Expand All @@ -126,35 +155,27 @@ def train(cfg: TrainPipelineConfig):
"step_scheduler_with_optimizer": False,
"split_batches": False, # split_batches == True is not working anyways
"kwargs_handlers": [DistributedDataParallelKwargs(find_unused_parameters=True)],
"gradient_accumulation_steps": cfg.gradient_accumulation_steps,
}
if cfg.wandb.enable:
accelerator_kwargs["log_with"] = "wandb"
if cfg.gradient_accumulation_steps > 1:
accelerator_kwargs["gradient_accumulation_steps"] = cfg.gradient_accumulation_steps

accelerator = accelerate.Accelerator(**accelerator_kwargs)
init_logging(accelerator, level=logging.DEBUG if cfg.debug else logging.INFO)
# Register the accelerator globally for use in other modules (e.g., to detect the current rank).
set_proc_accelerator(accelerator)

# Must run before `encode_accelerator_state_dict` + `init_trackers` below so the
# wandb-logged accelerator config and the value DeepSpeed consumes at prepare()
# time both reflect TrainPipelineConfig.
_sync_deepspeed_gradient_accumulation_steps(accelerator, cfg)

logging.info(pformat(cfg.to_dict()))

if accelerator.is_main_process:
accelerator_config = encode_accelerator_state_dict(accelerator.state.__dict__)
logging.info(pformat(accelerator_config))

# Ensure `gradient_accumulation_steps` is consistent between TrainPipelineConfig and DeepSpeedConfig
if accelerator.distributed_type == accelerate.DistributedType.DEEPSPEED:
deepspeed_config, deepspeed_key = accelerator.deepspeed_plugin.hf_ds_config.find_config_node(
"gradient_accumulation_steps"
)
ds_grad_acc_steps = deepspeed_config.get(deepspeed_key, 1)
if ds_grad_acc_steps != cfg.gradient_accumulation_steps:
raise ValueError(
"The `gradient_accumulation_steps` in TrainPipelineConfig does not match the value "
f"specified in DeepSpeedConfig {cfg.gradient_accumulation_steps} != {ds_grad_acc_steps}. " # nosec B608
)

if cfg.wandb.enable:
step = load_training_step(cfg.checkpoint_path) if cfg.resume else None
slurm_dict = {k: v for k, v in os.environ.items() if k.startswith("SLURM_")}
Expand Down
92 changes: 92 additions & 0 deletions tests/scripts/test_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from types import SimpleNamespace

import accelerate

from opentau.scripts.train import _sync_deepspeed_gradient_accumulation_steps


def _make_accelerator(distributed_type, ds_grad_acc, is_main_process=True):
plugin = SimpleNamespace(
hf_ds_config=SimpleNamespace(config={"gradient_accumulation_steps": ds_grad_acc}),
gradient_accumulation_steps=ds_grad_acc,
)
return SimpleNamespace(
distributed_type=distributed_type,
is_main_process=is_main_process,
deepspeed_plugin=plugin,
)


def _make_cfg(grad_acc):
return SimpleNamespace(gradient_accumulation_steps=grad_acc)


def test_non_deepspeed_is_noop(caplog):
    """With a non-DeepSpeed backend, both stored values stay untouched and nothing is logged."""
    fake_accelerator = _make_accelerator(accelerate.DistributedType.MULTI_GPU, ds_grad_acc=2)
    fake_cfg = _make_cfg(grad_acc=4)

    with caplog.at_level(logging.WARNING):
        _sync_deepspeed_gradient_accumulation_steps(fake_accelerator, fake_cfg)

    plugin = fake_accelerator.deepspeed_plugin
    # The original value (2) must survive even though the config asks for 4.
    assert plugin.hf_ds_config.config["gradient_accumulation_steps"] == 2
    assert plugin.gradient_accumulation_steps == 2
    assert not caplog.records


def test_deepspeed_matching_value_no_warning(caplog):
    """With DeepSpeed and already-matching values, nothing changes and no warning is logged."""
    fake_accelerator = _make_accelerator(accelerate.DistributedType.DEEPSPEED, ds_grad_acc=2)
    fake_cfg = _make_cfg(grad_acc=2)

    with caplog.at_level(logging.WARNING):
        _sync_deepspeed_gradient_accumulation_steps(fake_accelerator, fake_cfg)

    plugin = fake_accelerator.deepspeed_plugin
    assert plugin.hf_ds_config.config["gradient_accumulation_steps"] == 2
    assert plugin.gradient_accumulation_steps == 2
    # No record at WARNING level or above may have been captured.
    assert all(record.levelno < logging.WARNING for record in caplog.records)


def test_deepspeed_mismatch_overrides_and_warns_on_main(caplog):
    """On the main process, a mismatch is overridden and reported with exactly one warning."""
    fake_accelerator = _make_accelerator(
        accelerate.DistributedType.DEEPSPEED, ds_grad_acc=2, is_main_process=True
    )
    fake_cfg = _make_cfg(grad_acc=4)

    with caplog.at_level(logging.WARNING):
        _sync_deepspeed_gradient_accumulation_steps(fake_accelerator, fake_cfg)

    plugin = fake_accelerator.deepspeed_plugin
    assert plugin.hf_ds_config.config["gradient_accumulation_steps"] == 4
    assert plugin.gradient_accumulation_steps == 4

    warning_records = [
        record for record in caplog.records if record.levelno == logging.WARNING
    ]
    assert len(warning_records) == 1
    # The rendered message must mention both the old (2) and the new (4) value.
    rendered = warning_records[0].getMessage()
    assert "2" in rendered
    assert "4" in rendered


def test_deepspeed_mismatch_non_main_overrides_without_warning(caplog):
    """On a non-main process, a mismatch is still overridden but no warning is logged."""
    fake_accelerator = _make_accelerator(
        accelerate.DistributedType.DEEPSPEED,
        ds_grad_acc=2,
        is_main_process=False,
    )
    fake_cfg = _make_cfg(grad_acc=4)

    with caplog.at_level(logging.WARNING):
        _sync_deepspeed_gradient_accumulation_steps(fake_accelerator, fake_cfg)

    plugin = fake_accelerator.deepspeed_plugin
    assert plugin.hf_ds_config.config["gradient_accumulation_steps"] == 4
    assert plugin.gradient_accumulation_steps == 4
    # No record at WARNING level or above may have been captured.
    assert all(record.levelno < logging.WARNING for record in caplog.records)
Loading