Add ROCm benchmark workflow for MaxText#755
Conversation
96a4966 to
5e8fd90
Compare
mminutoli
left a comment
There was a problem hiding this comment.
many files are missing the new line at the end of the file.
| WHEELS_URL="${ROCM_WHEELS_BASE_URL}/${WHEELS_PATH%/}" | ||
| echo "Downloading ROCm wheels from ${WHEELS_URL}..." | ||
|
|
||
| LISTING=$(curl -fsSL "${WHEELS_URL}/") |
There was a problem hiding this comment.
Isn't just simpler to download the wheels locally instead of checking if they are list of files available on the page?
if the download of pjrt or plugin fail, then the action fails.
There was a problem hiding this comment.
That’s fair. The listing step is just for discovering the exact nightly filenames; it lets us fetch only the required PJRT + Python-specific plugin, instead of downloading all plugin wheels in that folder.
| run: | | ||
| set -euxo pipefail | ||
| chmod +x "./targets/${TARGET}/run.sh" | ||
| "./targets/${TARGET}/run.sh" --workload "${WORKLOAD}" |
There was a problem hiding this comment.
Maybe instead of an input TARGET should be set through a matrix. Thoughts?
There was a problem hiding this comment.
Yes, we can pass a matrix from the nightly-benchmark workflow.
There was a problem hiding this comment.
I have the feeling they won't like adding all of this stuff.
There was a problem hiding this comment.
Configs folder will live under ROCm/maxtext, along with requirements.txt.
| export XLA_PYTHON_CLIENT_MEM_FRACTION=.97 | ||
| export LD_LIBRARY_PATH=/usr/local/lib/:/opt/rocm/lib:$LD_LIBRARY_PATH | ||
| export NVTE_USE_HIPBLASLT=1 | ||
| export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_command_buffer='' --xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True --xla_gpu_autotune_level=4 --xla_gpu_enable_all_gather_combine_by_dim=FALSE" |
There was a problem hiding this comment.
are these specific to ROCm? I guess upstream might want to set their own flags
There was a problem hiding this comment.
Some of these are ROCm-related, but most are really model-run-specific. Since this lives under ROCm/maxtext, it might make sense to keep upstream‑level flags in run.sh or something like ci/envs/default.env, rather than in the config itself.
fff9d8b to
517bb14
Compare
eedae29 to
6f08c2a
Compare
| clone_main_xla: 0 | ||
|
|
||
| run-pytest-rocm: | ||
| run-benchmark-rocm: |
There was a problem hiding this comment.
It will be easier to maintain the amd-main branch if we create a new workflow for downstream and run it with on: schedule instead of removing upstream's nightly workflow and replacing it with our benchmark. Also, I think it's a good idea to keep unit tests and performance benchmarks in separate workflows.
| gcs_download_uri: ${{ inputs.gcs_download_uri }} | ||
| s3_download_uri: ${{ inputs.s3_download_uri }} | ||
| use-te: "1" | ||
| te-wheel-url: "https://github.com/ROCm/maxtext/releases/download/te-rocm-wheels-2026-05-04-86438dc3d04e/transformer_engine-2.12.0.dev0+86438dc3-1.mi355-cp312-cp312-linux_x86_64.whl" |
There was a problem hiding this comment.
Do we want to use the latest TE instead of hardcoding? We should use the gh CLI to do that https://cli.github.com/manual/gh_release_download. I'm pretty sure that the jax-dev image has it installed.
There was a problem hiding this comment.
Yes, I first updated the TE installation flow to resolve the wheel dynamically. I did this using the GitHub API&curl in bash script (See. 9183f20). After testing, I removed the TE dependency entirely from this benchmark workflow because the model can run w/o TE. For the current lightweight benchmark, TE doesnot provide additional value for detecting performance regressions and keeping it removed makes the workflow simpler and avoids extra debugging overhead.
|
|
||
| CFG_FILE="${REPO_DIR}/configs/models/${WORKLOAD}.yml" | ||
| ENV_FILE="${REPO_DIR}/configs/models/${WORKLOAD}.env.sh" | ||
| REQ_FILE="${WORK_DIR}/dependencies/requirements/requirements_rocm_jax_0.8.2.txt" |
There was a problem hiding this comment.
Is there a reason we're still using JAX 0.8.2? Do we want to move to something more recent?
There was a problem hiding this comment.
Yes, at that time I reused the existing requirements file because it was already working for validation/testing, even though the filename postfixed 0_8_2, it actually was using unpinned package versions. I aligned the setup with the planned configuration from ROCm/maxtext#87, so we can move forward with this. MaxText is also planning to provide rocm_extra, which can help remove the need for requirement file later as well.
| BENCHMARK_JSON="${RUN_DIR}/benchmark.json" | ||
| RESULT_JSON="${RUN_DIR}/result.json" | ||
|
|
||
| PYTHON_BIN="${JAXCI_PYTHON:-python3}" |
There was a problem hiding this comment.
Same comment as below about passing things into scripts via environment variables and command line arguments. It breaks modularity and makes scripts hard to understand when we depend on environment variables rather than passing things in via command line arguments.
Passing in command line arguments is more of a pain in Bash, so I'm more okay with letting environment variables slide. But could we at least put all the environment variables that the script expects to use near the top of the script? It still breaks modularity, but at least it's easy for a person reading the script to see which environment variables need to be set.
There was a problem hiding this comment.
Make sense. I kept the exisiting CI/JAX environment variables in the bash script, but grouped the expected ones near the top so the dependencies are more visible.
…tignore (#563) When jaxlib was built in debug more, an assertion in LLVM code that lazy-loads VHLO dialect could fire, since the code path could execute in a multi-threaded environment, and LLVM dialect repositories aren't thread safe to modify. This patch applies the same changes that upstream makes to fix this: jax-ml@48c8762 (this includes disabling a call to `jax_mlir_ext.enter_multi_threaded_execution(context)` in `mlir.py`. Presumably, the whole functionality related to `enter_multi_threaded_execution()` multithreaded checks isn't ready yet, and it was prematurely rolled into the production code. Manual testing
(forgot this skip in the previous PR)
7101e89 to
4a5afee
Compare
555e4c2 to
8405330
Compare
0892795 to
cd5ad81
Compare
cd5ad81 to
34b2d0f
Compare
|
No further action needed, closing the PR. |
No description provided.