[FIX DDP] fix ddp (#8549)
* enable trainer tests.
ZHUI committed Jun 7, 2024
1 parent 4e3f60d commit f89c91d
Showing 6 changed files with 90 additions and 92 deletions.
13 changes: 2 additions & 11 deletions paddlenlp/trainer/trainer.py
@@ -1795,17 +1795,8 @@ def _wrap_model(self, model, training=True):
in_cp_parallel_mode = self.args.context_parallel_degree > 1

# Multi-gpu training
if (
self.args.world_size > 1
and not self.args.use_hybrid_parallel
or not (
in_pipeline_parallel_mode
or in_sharding_parallel_mode
or in_tensor_parallel_mode
or in_sep_parallel_mode
or in_cp_parallel_mode
)
):
if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
# MoE uses DDP to broadcast parameters.
model = paddle.DataParallel(model)
# Distributed training (should be after fp16 initialization)

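After this change, the Trainer falls back to `paddle.DataParallel` whenever it runs on multiple GPUs without hybrid parallelism, instead of also checking the individual parallel-mode flags. A minimal sketch of the resulting behavior; `maybe_wrap_ddp` and the bare `args` object are illustrative stand-ins for the Trainer's `_wrap_model` and `TrainingArguments`, not the actual code:

```python
import paddle

def maybe_wrap_ddp(model, args):
    # Pure multi-GPU data parallelism: wrap with DDP so gradients are synchronized.
    # Hybrid-parallel setups (TP/PP/sharding/SEP/CP) manage their own
    # communication groups and are left unwrapped here.
    if args.world_size > 1 and not args.use_hybrid_parallel:
        model = paddle.DataParallel(model)
    return model
```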
2 changes: 1 addition & 1 deletion paddlenlp/trainer/training_args.py
@@ -1529,7 +1529,7 @@ def is_segment_parallel_supported():
if world_size > 1:
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
if self.unified_checkpoint:
self.use_hybrid_parallel = True
# DP uses the hybrid group
strategy = fleet.DistributedStrategy()
fleet.init(is_collective=True, strategy=strategy)
else:
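For a pure data-parallel run with `unified_checkpoint` enabled, the branch above now initializes the collective fleet context (so DP can reuse the hybrid communication group) rather than forcing `use_hybrid_parallel` on. A standalone sketch of that path, assuming only the names visible in the diff; `init_dp_context` is a hypothetical helper, not part of `training_args.py`:

```python
import paddle
from paddle.distributed import fleet

def init_dp_context(world_size: int, unified_checkpoint: bool) -> None:
    # Initialize only when running distributed and not already initialized.
    if world_size > 1 and not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
        if unified_checkpoint:
            # DP reuses the hybrid group: start fleet in collective mode.
            strategy = fleet.DistributedStrategy()
            fleet.init(is_collective=True, strategy=strategy)
```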
5 changes: 1 addition & 4 deletions pyproject.toml
@@ -11,9 +11,6 @@ exclude = ['.flake8']
[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q --dist loadgroup"
retries = 0
retry_delay = 0.5
timeout = 200
pythonpath = ["."]
testpaths = [
"tests/data",
@@ -25,7 +22,7 @@ testpaths = [
"tests/layers",
"tests/metrics",
"tests/ops",
# "tests/trainer",
"tests/trainer",
"tests/transformers",
"tests/peft",
"tests/prompt",
6 changes: 3 additions & 3 deletions scripts/unit_test/ci_unit.sh
@@ -30,10 +30,10 @@ install_requirements() {
python -m pip install -r requirements-dev.txt
python -m pip install -r tests/requirements.txt
python -m pip install -r paddlenlp/experimental/autonlp/requirements.txt
python -m pip uninstall paddlepaddle -y
python -m pip uninstall paddlepaddle paddlepaddle_gpu -y
python -m pip install --no-cache-dir ${paddle}

python setup.py bdist_wheel
python setup.py bdist_wheel > /dev/null
python -m pip install dist/p****.whl
cd csrc/
python setup_cuda.py install
@@ -51,4 +51,4 @@ set_env() {

install_requirements
set_env
pytest -v -n 8 --durations 20
pytest -v -n 8 --timeout 200 --durations 20 --cov paddlenlp --cov-report xml:coverage.xml
47 changes: 25 additions & 22 deletions tests/trainer/test_lora_unified_checkpoint.py
@@ -151,7 +151,7 @@ def __test__(cls):

def setUp(self):
"""
1. update runfrist and rerun to run defined different config
1. update runfirst and rerun to run defined different config
2. update need_allclose to True if you want to check the result
3. update rtol to the relative value you want to check
"""
@@ -171,7 +171,7 @@ def setUp(self):

self.run_lora_file = "llm/finetune_generation.py"

def runfrist(self, train_args):
def runfirst(self, train_args):
self.run_n1c8(self.run_lora_file, **train_args)

def rerun(self, train_args):
@@ -183,7 +183,7 @@ def testTP4PP2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP4PP2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -198,7 +198,7 @@ def testTP2Sharding4(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP2Sharding4"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -216,7 +216,7 @@ def testTP8(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP8"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -230,7 +230,7 @@ def testTP4DP2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP4DP2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -245,7 +245,7 @@ def testTP4Sharding2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP4Sharding2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -260,7 +260,7 @@ def testTP2PP4(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP2PP4"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -275,7 +275,7 @@ def testPP8(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["PP8"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -290,7 +290,7 @@ def testPP4DP2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["PP4DP2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -305,7 +305,7 @@ def testPP4Sharding2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["PP4Sharding2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -320,7 +320,7 @@ def testSharding8S1(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding8S1"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -335,7 +335,7 @@ def testSharding8S2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding8S2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -350,7 +350,7 @@ def testSharding4S1DP2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding4S1DP2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -365,7 +365,7 @@ def testSharding4S2DP2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding4S2DP2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -380,7 +380,7 @@ def testSharding2S1DP4(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding2S1DP4"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -395,7 +395,7 @@ def testSharding2S2DP4(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding2S2DP4"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -410,7 +410,7 @@ def testDP8(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["DP8"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -419,27 +419,29 @@ def testDP8(self):
np.testing.assert_allclose(res[0], res[1], self.rtol)


@pytest.mark.skipif(True, reason="Skip for None CE")
class TestUnifiedCheckpointOnN2C4(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()
self.need_allclose = True
self.rtol = 1e-7

def runfrist(self, train_args):
def runfirst(self, train_args):
self.run_n2c4(self.run_lora_file, **train_args)

def rerun(self, train_args):
self.run_n2c4(self.run_lora_file, **train_args)


@pytest.mark.skipif(True, reason="Skip for None CE")
class TestUnifiedCheckpointOnN1C8CheckpointCompatible(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()

self.need_allclose = True
self.rtol = 1e-7

def runfrist(self, train_args):
def runfirst(self, train_args):
train_args["unified_checkpoint"] = 0
self.run_n1c8(self.run_lora_file, **train_args)

@@ -448,14 +450,15 @@ def rerun(self, train_args):
self.run_n1c8(self.run_lora_file, **train_args)


@pytest.mark.skipif(True, reason="Skip for None CE")
class TestPaddleCheckpointOnN1C8Reset(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()

self.need_allclose = True
self.rtol = 1e-7

def runfrist(self, train_args):
def runfirst(self, train_args):
train_args["unified_checkpoint"] = 0
self.run_n1c8(self.run_lora_file, **train_args)

@@ -472,7 +475,7 @@ def setUp(self):
self.need_allclose = True
self.rtol = 1e-7

def runfrist(self, train_args):
def runfirst(self, train_args):
train_args["unified_checkpoint"] = 0
self.run_n2c4(self.run_lora_file, **train_args)

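The test classes above all follow the same runfirst/rerun pattern: train once under a given distributed config, rerun so training resumes from the checkpoint just written, then compare the recorded losses when `need_allclose` is set. A hedged sketch of that flow; `run_case` and `read_losses` are illustrative stand-ins, not the real test helpers:

```python
import numpy as np

def check_checkpoint_roundtrip(run_case, read_losses, train_args, rtol=1e-7):
    run_case(**train_args)         # first run: trains and saves the (unified) checkpoint
    run_case(**train_args)         # rerun: resumes from that checkpoint
    first, second = read_losses()  # e.g. losses parsed from the two runs' logs
    np.testing.assert_allclose(first, second, rtol)
```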
