From f94a915a0bf99129699cfb53f1bec464e509fa1d Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Sun, 19 May 2024 08:19:37 +0000 Subject: [PATCH] fix tests --- scripts/regression/ci_case.sh | 11 +++++++++++ scripts/regression/run_ci.sh | 2 +- tests/transformers/test_ring_flash_attention.py | 8 +++----- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index 32cfec4b59de..e19a42f8a756 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -1111,5 +1111,16 @@ else echo "only one gpu:${cudaid1} is set, skip test" fi +} +ring_flash_attention(){ +cd ${nlp_dir} +echo "test ring_flash_attention, cudaid1:${cudaid1}, cudaid2:${cudaid2}" +if [[ ${cudaid1} != ${cudaid2} ]]; then + time (python -m paddle.distributed.launch tests/transformers/test_ring_flash_attention.py >${log_path}/ring_flash_attention) >>${log_path}/ring_flash_attention 2>&1 + print_info $? ring_flash_attention +else + echo "only one gpu:${cudaid1} is set, skip test" +fi + } $1 diff --git a/scripts/regression/run_ci.sh b/scripts/regression/run_ci.sh index 74d0b1957af8..0f7f6fdf5ab0 100644 --- a/scripts/regression/run_ci.sh +++ b/scripts/regression/run_ci.sh @@ -33,7 +33,7 @@ all_P0case_dic=(["waybill_ie"]=3 ["msra_ner"]=15 ["glue"]=2 ["bert"]=2 ["skep"]= ["ernie-ctm"]=5 ["distilbert"]=5 ["transformer"]=5 ["pet"]=5 ["efl"]=5 ["p-tuning"]=5 ["ernie-doc"]=20 ["transformer-xl"]=5 \ ["question_matching"]=5 ["ernie-csc"]=5 ["nptag"]=5 ["ernie-m"]=5 ["taskflow"]=5 ["clue"]=5 ["textcnn"]=5 \ ["fast_generation"]=10 ["ernie-3.0"]=5 ["ernie-layout"]=5 ["uie"]=5 ["ernie-health"]=5 ["llm"]=5 \ -["ernie"]=2 ["ernie_m"]=5 ["ernie_layout"]=5 ["ernie_csc"]=5 ["ernie_ctm"]=5 ["ernie_doc"]=20 ["ernie_health"]=5 ["segment_parallel_utils"]=5) +["ernie"]=2 ["ernie_m"]=5 ["ernie_layout"]=5 ["ernie_csc"]=5 ["ernie_ctm"]=5 ["ernie_doc"]=20 ["ernie_health"]=5 ["segment_parallel_utils"]=5 ["ring_flash_attention"]=5) #################################### # Insatll paddlepaddle-gpu install_paddle(){ diff --git a/tests/transformers/test_ring_flash_attention.py b/tests/transformers/test_ring_flash_attention.py index 025e1530b0ce..134d2f9c011a 100644 --- a/tests/transformers/test_ring_flash_attention.py +++ b/tests/transformers/test_ring_flash_attention.py @@ -53,6 +53,8 @@ def split_belanced_data(self, input): return paddle.concat([sliced_data0, sliced_data1], axis=1).detach() def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, use_mask): + if self.degree < 2: + return query, key, value = self.generate_full_data(bsz, seq_len_per_device * self.degree, head_num, head_dim) local_query = self.split_belanced_data(query) @@ -118,9 +120,5 @@ def test_casual_flash_attention(self): self.single_test(2, 1024, 2, 128, True, False) -def main(): - unittest.main() - - if __name__ == "__main__": - paddle.distributed.spawn(main, nprocs=4) + unittest.main()