Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyuqin1998 committed May 19, 2024
1 parent 9a4a518 commit f94a915
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
11 changes: 11 additions & 0 deletions scripts/regression/ci_case.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1111,5 +1111,16 @@ else
echo "only one gpu:${cudaid1} is set, skip test"
fi

}
# Run the ring_flash_attention distributed test; needs two distinct GPUs.
# Globals (read): nlp_dir, cudaid1, cudaid2, log_path
# Calls: print_info (defined elsewhere in this script) with the test's exit status.
ring_flash_attention(){
    # Bail out if the repo dir is missing instead of running from the wrong cwd.
    cd "${nlp_dir}" || return
    echo "test ring_flash_attention, cudaid1:${cudaid1}, cudaid2:${cudaid2}"
    # Quote both sides so the comparison is a literal string test, not a
    # glob-pattern match against an unquoted RHS (ShellCheck SC2053).
    if [[ "${cudaid1}" != "${cudaid2}" ]]; then
        # Inner redirect truncates the log with the test's stdout; the outer
        # append collects `time` output and stderr into the same file.
        # Paths are quoted so a log_path containing spaces cannot word-split (SC2086).
        time (python -m paddle.distributed.launch tests/transformers/test_ring_flash_attention.py >"${log_path}/ring_flash_attention") >>"${log_path}/ring_flash_attention" 2>&1
        print_info $? ring_flash_attention
    else
        echo "only one gpu:${cudaid1} is set, skip test"
    fi
}
$1
2 changes: 1 addition & 1 deletion scripts/regression/run_ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ all_P0case_dic=(["waybill_ie"]=3 ["msra_ner"]=15 ["glue"]=2 ["bert"]=2 ["skep"]=
["ernie-ctm"]=5 ["distilbert"]=5 ["transformer"]=5 ["pet"]=5 ["efl"]=5 ["p-tuning"]=5 ["ernie-doc"]=20 ["transformer-xl"]=5 \
["question_matching"]=5 ["ernie-csc"]=5 ["nptag"]=5 ["ernie-m"]=5 ["taskflow"]=5 ["clue"]=5 ["textcnn"]=5 \
["fast_generation"]=10 ["ernie-3.0"]=5 ["ernie-layout"]=5 ["uie"]=5 ["ernie-health"]=5 ["llm"]=5 \
["ernie"]=2 ["ernie_m"]=5 ["ernie_layout"]=5 ["ernie_csc"]=5 ["ernie_ctm"]=5 ["ernie_doc"]=20 ["ernie_health"]=5 ["segment_parallel_utils"]=5)
["ernie"]=2 ["ernie_m"]=5 ["ernie_layout"]=5 ["ernie_csc"]=5 ["ernie_ctm"]=5 ["ernie_doc"]=20 ["ernie_health"]=5 ["segment_parallel_utils"]=5 ["ring_flash_attention"]=5)
####################################
# Install paddlepaddle-gpu
install_paddle(){
Expand Down
8 changes: 3 additions & 5 deletions tests/transformers/test_ring_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def split_belanced_data(self, input):
return paddle.concat([sliced_data0, sliced_data1], axis=1).detach()

def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, use_mask):
if self.degree < 2:
return
query, key, value = self.generate_full_data(bsz, seq_len_per_device * self.degree, head_num, head_dim)

local_query = self.split_belanced_data(query)
Expand Down Expand Up @@ -118,9 +120,5 @@ def test_casual_flash_attention(self):
self.single_test(2, 1024, 2, 128, True, False)


def main():
unittest.main()


if __name__ == "__main__":
paddle.distributed.spawn(main, nprocs=4)
unittest.main()

0 comments on commit f94a915

Please sign in to comment.