-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add FMHA PAXML test #830
Open
hmonishN
wants to merge
21
commits into
main
Choose a base branch
from
hmonish/add_fmha_paxml_test
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add FMHA PAXML test #830
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
a9a828b
add fmha related changes
hmonishN 69f9db9
Update _test_upstream_pax.yaml
hmonishN 35565ce
Update _sandbox.yaml
hmonishN 42566a6
Update _test_upstream_pax.yaml
hmonishN 495ae32
Update _sandbox.yaml
hmonishN 23a19d8
Update test-pax.sh
hmonishN 4be1f40
Update _sandbox.yaml
hmonishN cb9d7d5
Merge branch 'main' into hmonish/add_fmha_paxml_test
hmonishN f44cdef
removing hlo dir for llama test.
hmonishN da69dbd
Update _test_upstream_pax.yaml
hmonishN a6622c8
Update _test_upstream_pax.yaml
hmonishN b67229b
Update test-pax.sh
hmonishN f7618bf
Update test-pax.sh
hmonishN ab22fcc
Merge branch 'main' into hmonish/add_fmha_paxml_test
hmonishN dbb999d
Update test-pax.sh
hmonishN e0f4b4e
Merge branch 'main' into hmonish/add_fmha_paxml_test
hmonishN 1898b32
Merge branch 'main' into hmonish/add_fmha_paxml_test
hmonishN c1ff8ae
Update _test_pax_rosetta.yaml
hmonishN 0ef811a
Update _test_upstream_pax.yaml
hmonishN ca6e2e9
Update test-pax.sh
hmonishN 7084812
Update _test_upstream_pax.yaml
hmonishN File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,12 +17,13 @@ usage() { | |
echo " --dtype Batch size, defaults to bfloat16." | ||
echo " --enable-te If set, will run with env var ENABLE_TE=1." | ||
echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1." | ||
echo " --disable-fused-attn Whether disable TE fused attention." | ||
echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M" | ||
echo " --evaluate Whether to test evaluation rather than training." | ||
echo " -s, --steps Number of steps to run, defaults to 500." | ||
echo " --multiprocess Enable the multiprocess GPU mode." | ||
echo " -o, --output NAME Name for the output folder, a temporary folder will be created if none specified." | ||
echo " --save-hlo {0, 1} 1 to save the dumped hlo, 0 to remove the hlo dumped folder" | ||
echo " --enable-fmha {0, 1} 1 to enable fmha testing, 0 to run test without fmha; default is 0" | ||
echo " --data-parallel Data parallelism to use. Defaults to 1." | ||
echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1." | ||
echo " --tensor-parallel Tensor parallelism to use. Defaults to 1." | ||
|
@@ -32,7 +33,8 @@ usage() { | |
exit $1 | ||
} | ||
|
||
args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@") | ||
args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,model-type:,enable-fmha:,evaluate,steps:,help,multiprocess,output:,save-hlo:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@") | ||
|
||
if [[ $? -ne 0 ]]; then | ||
exit $1 | ||
fi | ||
|
@@ -55,6 +57,8 @@ NVTE_FUSED_ATTN=1 | |
DROPOUT=0 | ||
EVALUATE=0 | ||
ADDITIONAL_ARGS="" | ||
ENABLE_FMHA=${ENABLE_FMHA:-1} | ||
SAVE_HLO=${SAVE_HLO:-0} | ||
|
||
eval set -- "$args" | ||
while [ : ]; do | ||
|
@@ -75,14 +79,15 @@ while [ : ]; do | |
ENABLE_TE=1 | ||
shift 1 | ||
;; | ||
--enable-fmha) | ||
ENABLE_FMHA="$2" | ||
NVTE_FUSED_ATTN="$2" | ||
shift 2 | ||
;; | ||
--enable-dropout) | ||
DROPOUT='0.1' | ||
shift 1 | ||
;; | ||
--disable-fused-attn) | ||
NVTE_FUSED_ATTN=0 | ||
shift 1 | ||
;; | ||
--model-type) | ||
MODEL_TYPE=$2 | ||
shift 2 | ||
|
@@ -103,6 +108,10 @@ while [ : ]; do | |
OUTPUT=$2 | ||
shift 2 | ||
;; | ||
--save-hlo) | ||
SAVE_HLO="$2" | ||
shift 2 | ||
;; | ||
--data-parallel) | ||
DP="$2" | ||
shift 2 | ||
|
@@ -136,6 +145,21 @@ while [ : ]; do | |
esac | ||
done | ||
|
||
# Set hlo dump folder after output folder is set. | ||
HLO_DIR=${OUTPUT}/hlo | ||
export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}" | ||
export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}" | ||
echo "HLO will be dumped in ${HLO_DIR} dir." | ||
|
||
## Setting the env variables for FMHA | ||
if [[ "$ENABLE_FMHA" -eq "1" ]]; then | ||
echo "Setting XLA FMHA Flags"; | ||
export BASE_XLA_FLAGS_FMHA="${BASE_XLA_FLAGS_FMHA:---xla_gpu_fused_attention_use_cudnn_rng=true --xla_gpu_enable_cudnn_fmha=true}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Save here as above |
||
export XLA_FLAGS="${BASE_XLA_FLAGS_FMHA} ${XLA_FLAGS:-}" | ||
fi | ||
|
||
echo "XLA FLAGS: $XLA_FLAGS" | ||
|
||
# # Set derived variables | ||
|
||
GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU') | ||
|
@@ -149,8 +173,10 @@ print_var NGPUS | |
print_var OUTPUT | ||
print_var MULTIPROCESS | ||
print_var ENABLE_TE | ||
print_var ENABLE_FMHA | ||
print_var NVTE_FUSED_ATTN | ||
print_var EVALUATE | ||
print_var SAVE_HLO | ||
print_var DROPOUT | ||
print_var DP | ||
print_var FSDP | ||
|
@@ -422,5 +448,25 @@ else | |
$([[ $MULTIPROCESS != 0 ]] && echo --multiprocess_gpu) | ||
fi | ||
|
||
echo "Checking for FMHA instructions in HLO!" | ||
|
||
if [[ "$ENABLE_FMHA" -eq "1" ]]; then | ||
## Check if fmha instructions are present in the HLO dumped file or not. | ||
fmha_regex="fmha[-bmm]?[-scale]?[-bias]?[-mask]?[-softmax]?[-dropout]?[-bmm]?[-backward]?*" | ||
result=$(grep -irlnE "$fmha_regex" "${HLO_DIR}/"*.txt) | ||
|
||
if [ -z "$result" ]; then | ||
echo "E: No FMHA instructions were found in the hlo files!" | ||
exit 1 | ||
else | ||
echo -e "Found FMHA instructions in the following HLO files: \n $result" | ||
fi | ||
fi | ||
|
||
if [[ $SAVE_HLO -eq 0 ]]; then | ||
rm -rf $HLO_DIR | ||
echo "Removed dumped HLO directory!" | ||
fi | ||
|
||
set +x | ||
echo "Output at ${OUTPUT}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please explain logic here: is
BASE_XLA_FLAGS
is set, than you always skip settingHLO_DIR
?If so, maybe you can add a warning message, that xla dump is not set?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dumping the hlo is enabled by default in BASE_XLA_FLAGS, and BASE_XLA_FLAGS are appended to XLA_FLAGS env var. if user wants to test fmha then BASE_XLA_FLAGS_FMHA is added and appended to XLA_FLAGS. The idea is to preserve the env var XLA_FLAGS before execution of this script.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, let me clarify my question:
line 150 literally means:
Meaning, that if
BASE_XLA_FLAGS
is already set (by any previous scripts, or globally in the system, etc), ${HLO_DIR} will not have any effect at all.Is that expected behaviour?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And why do you
export
it? You use it only locally.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mechanism was added as per the review comment of same PR for t5x: #442 (comment)
refer to the discussion for details of the implementation.
The implementation BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}" means update the BASE_XLA_FLAGS with previous definition if any and append xla dump hlo flags to the env vars. This also gives us the flexibility of "zero out" the env var in this script without modifying code in this script by just doing BASE_XLA_FLAGS=""