Add ROCm benchmark workflow for MaxText #755

Closed
psanal35 wants to merge 50 commits into amd-main from add-rocm-model-benchmarks

@psanal35

No description provided.

@psanal35 psanal35 force-pushed the add-rocm-model-benchmarks branch 5 times, most recently from 96a4966 to 5e8fd90 Compare April 23, 2026 18:58

@mminutoli mminutoli left a comment

Many files are missing the newline at the end of the file.

WHEELS_URL="${ROCM_WHEELS_BASE_URL}/${WHEELS_PATH%/}"
echo "Downloading ROCm wheels from ${WHEELS_URL}..."

LISTING=$(curl -fsSL "${WHEELS_URL}/")

@mminutoli mminutoli Apr 23, 2026

Isn't it simpler to just download the wheels locally instead of checking whether they are in the list of files available on the page?

If the download of the PJRT or plugin wheel fails, then the action fails.

Author

That’s fair. The listing step is just for discovering the exact nightly filenames; it lets us fetch only the required PJRT + Python-specific plugin, instead of downloading all plugin wheels in that folder.
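
A minimal sketch of that discovery step, assuming the index page is a plain HTML directory listing with `href="...whl"` links. The helper name `pick_wheel`, the wheel name prefixes in the usage comment, and the version-sort tie-break are assumptions for illustration, not taken from the PR; URL-encoded characters in filenames are not handled.

```shell
# Hypothetical helper: from an HTML listing, print the wheel filename that
# starts with PREFIX and, when PY_TAG is given (e.g. cp312), contains that
# Python tag. If several match, the last one in version-sort order wins.
# Prints nothing when no wheel matches.
pick_wheel() {
  local listing="$1" prefix="$2" py_tag="${3:-}"
  printf '%s\n' "${listing}" \
    | grep -oE 'href="[^"]+\.whl"' \
    | sed -E 's/^href="//; s/"$//; s#.*/##' \
    | grep -E "^${prefix}-" \
    | { if [ -n "${py_tag}" ]; then grep -F -- "-${py_tag}-"; else cat; fi; } \
    | sort -V \
    | tail -n 1
}

# Usage (prefixes are placeholders for the real wheel names):
#   LISTING=$(curl -fsSL "${WHEELS_URL}/")
#   PJRT_WHEEL=$(pick_wheel "${LISTING}" "jax_rocm7_pjrt")
#   PLUGIN_WHEEL=$(pick_wheel "${LISTING}" "jax_rocm7_plugin" "cp312")
#   curl -fsSLO "${WHEELS_URL}/${PJRT_WHEEL}"
```

Only the two selected files are then downloaded, and an empty result can be treated as a hard failure before any download starts.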

Comment thread .github/workflows/benchmark_rocm.yml Outdated
run: |
set -euxo pipefail
chmod +x "./targets/${TARGET}/run.sh"
"./targets/${TARGET}/run.sh" --workload "${WORKLOAD}"

Maybe TARGET should be set through a matrix instead of an input. Thoughts?

Author

Yes, we can pass a matrix from the nightly-benchmark workflow.

Comment thread targets/maxtext/configs/llama3_8b.yml Outdated

I have the feeling they won't like adding all of this stuff.

Author

@psanal35 psanal35 Apr 23, 2026

The configs folder will live under ROCm/maxtext, along with requirements.txt.

export XLA_PYTHON_CLIENT_MEM_FRACTION=.97
export LD_LIBRARY_PATH=/usr/local/lib/:/opt/rocm/lib:$LD_LIBRARY_PATH
export NVTE_USE_HIPBLASLT=1
export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_command_buffer='' --xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True --xla_gpu_autotune_level=4 --xla_gpu_enable_all_gather_combine_by_dim=FALSE"

Are these specific to ROCm? I guess upstream might want to set their own flags.

Author

@psanal35 psanal35 Apr 23, 2026

Some of these are ROCm-related, but most are really model-run-specific. Since this lives under ROCm/maxtext, it might make sense to keep upstream‑level flags in run.sh or something like ci/envs/default.env, rather than in the config itself.
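
One way to sketch that split, assuming shared defaults live in a file such as `ci/envs/default.env` and each workload has its own env file that is sourced afterwards. The helper name `load_env_layers` and the layering order are assumptions for illustration, not what the PR implements.

```shell
# Hypothetical helper: source shared defaults first, then the per-workload
# file, so workload-specific values override the defaults. A missing file is
# skipped rather than treated as an error.
load_env_layers() {
  local default_env="$1" workload_env="$2"
  set -a   # export every variable assigned while sourcing
  if [ -f "${default_env}" ]; then . "${default_env}"; fi
  if [ -f "${workload_env}" ]; then . "${workload_env}"; fi
  set +a
}

# Usage (paths follow the names mentioned in this thread):
#   load_env_layers "ci/envs/default.env" "configs/models/${WORKLOAD}.env.sh"
```

With this ordering, a flag such as `XLA_PYTHON_CLIENT_MEM_FRACTION` can have a repo-wide default and still be tuned per model.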

@psanal35 psanal35 force-pushed the add-rocm-model-benchmarks branch 6 times, most recently from fff9d8b to 517bb14 Compare April 24, 2026 20:07
@psanal35 psanal35 force-pushed the add-rocm-model-benchmarks branch 2 times, most recently from eedae29 to 6f08c2a Compare May 4, 2026 19:44
clone_main_xla: 0

run-pytest-rocm:
run-benchmark-rocm:
Collaborator

It will be easier to maintain the amd-main branch if we create a new workflow for downstream and run it with on: schedule instead of removing upstream's nightly workflow and replacing it with our benchmark. Also, I think it's a good idea to keep unit tests and performance benchmarks in separate workflows.

Comment thread ci/collect_run_manifest_rocm.py Outdated
Comment thread ci/collect_run_manifest_rocm.py Outdated
gcs_download_uri: ${{ inputs.gcs_download_uri }}
s3_download_uri: ${{ inputs.s3_download_uri }}
use-te: "1"
te-wheel-url: "https://github.com/ROCm/maxtext/releases/download/te-rocm-wheels-2026-05-04-86438dc3d04e/transformer_engine-2.12.0.dev0+86438dc3-1.mi355-cp312-cp312-linux_x86_64.whl"
Collaborator

Do we want to use the latest TE instead of hardcoding the URL? We should use the gh CLI to do that: https://cli.github.com/manual/gh_release_download. I'm pretty sure that the jax-dev image has it installed.

Author

Yes, I first updated the TE installation flow to resolve the wheel dynamically. I did this using the GitHub API and curl in a bash script (see 9183f20). After testing, I removed the TE dependency entirely from this benchmark workflow because the model can run without TE. For the current lightweight benchmark, TE does not provide additional value for detecting performance regressions, and keeping it removed makes the workflow simpler and avoids extra debugging overhead.
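
For reference, the gh-based alternative suggested in this thread could look roughly like the sketch below. The helper name `download_latest_wheel` and the `GH_BIN` override are assumptions added for illustration. It also assumes the wheel is attached to the repository's latest release, which may not hold if other releases are published in between. In GitHub Actions, `gh` additionally needs a token in `GH_TOKEN`.

```shell
# Hypothetical helper: download release assets matching PATTERN from the
# latest release of REPO into DEST.
# GH_BIN is an assumption of this sketch; it lets a dry run swap out gh.
download_latest_wheel() {
  local repo="$1" pattern="$2" dest="$3"
  mkdir -p "${dest}"
  # With no tag argument, `gh release download` uses the latest release
  # and requires --pattern (or --archive).
  "${GH_BIN:-gh}" release download \
    --repo "${repo}" --pattern "${pattern}" --dir "${dest}" --clobber
}

# Usage (pattern is a guess based on the wheel filename quoted above):
#   download_latest_wheel "ROCm/maxtext" "transformer_engine-*cp312*.whl" ./wheels
```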


CFG_FILE="${REPO_DIR}/configs/models/${WORKLOAD}.yml"
ENV_FILE="${REPO_DIR}/configs/models/${WORKLOAD}.env.sh"
REQ_FILE="${WORK_DIR}/dependencies/requirements/requirements_rocm_jax_0.8.2.txt"
Collaborator

Is there a reason we're still using JAX 0.8.2? Do we want to move to something more recent?

Author

Yes, at that time I reused the existing requirements file because it was already working for validation and testing. Although the filename carries a 0.8.2 suffix, it actually used unpinned package versions. I have since aligned the setup with the planned configuration from ROCm/maxtext#87, so we can move forward with this. MaxText is also planning to provide rocm_extra, which could later remove the need for a requirements file as well.

Comment thread ci/collect_run_manifest_rocm.py Outdated
Comment thread ci/benchmark_targets/maxtext_rocm/run_maxtext_rocm.sh Outdated
BENCHMARK_JSON="${RUN_DIR}/benchmark.json"
RESULT_JSON="${RUN_DIR}/result.json"

PYTHON_BIN="${JAXCI_PYTHON:-python3}"
Collaborator

Same comment as below about passing things into scripts via environment variables and command line arguments. It breaks modularity and makes scripts hard to understand when we depend on environment variables rather than passing things in via command line arguments.

Passing in command line arguments is more of a pain in Bash, so I'm more okay with letting environment variables slide. But could we at least put all the environment variables that the script expects to use near the top of the script? It still breaks modularity, but at least it's easy for a person reading the script to see which environment variables need to be set.

Author

Makes sense. I kept the existing CI/JAX environment variables in the bash script, but grouped the expected ones near the top so the dependencies are more visible.
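
A minimal sketch of that grouping, assuming the script wants to fail early when a required variable is missing. `JAXCI_PYTHON` comes from the snippet above; the helper name `require_env` and the choice of `RUN_DIR` and `WORKLOAD` as required inputs are assumptions for illustration.

```shell
# --- Environment variables this script reads (kept together at the top) ---
# Optional: falls back to python3 when JAXCI_PYTHON is unset or empty.
: "${JAXCI_PYTHON:=python3}"

# Hypothetical helper: report every required variable that is unset or
# empty in one message, then return non-zero.
require_env() {
  local missing="" name value
  for name in "$@"; do
    eval "value=\${${name}:-}"
    [ -n "${value}" ] || missing="${missing} ${name}"
  done
  if [ -n "${missing}" ]; then
    echo "missing required environment variables:${missing}" >&2
    return 1
  fi
}

# Usage (which variables are required is an assumption):
#   require_env RUN_DIR WORKLOAD || exit 1
#   PYTHON_BIN="${JAXCI_PYTHON}"
```

Listing all missing names at once saves a round trip per variable when a CI job is misconfigured.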

charleshofer and others added 15 commits May 6, 2026 13:07
…tignore (#563)

When jaxlib was built in debug mode, an assertion in LLVM code that lazy-loads the VHLO dialect could fire, since the code path could execute in a multi-threaded environment, and LLVM dialect repositories aren't thread-safe to modify.

This patch applies the same changes that upstream makes to fix this: jax-ml@48c8762

(This includes disabling a call to `jax_mlir_ext.enter_multi_threaded_execution(context)` in `mlir.py`. Presumably, the functionality behind the `enter_multi_threaded_execution()` multithreading checks isn't ready yet and was rolled into the production code prematurely.)

Manual testing
(forgot this skip in the previous PR)
@psanal35 psanal35 force-pushed the add-rocm-model-benchmarks branch 5 times, most recently from 7101e89 to 4a5afee Compare May 10, 2026 22:22
@psanal35 psanal35 force-pushed the add-rocm-model-benchmarks branch 5 times, most recently from 555e4c2 to 8405330 Compare May 11, 2026 01:48
@psanal35 psanal35 force-pushed the add-rocm-model-benchmarks branch 6 times, most recently from 0892795 to cd5ad81 Compare May 11, 2026 15:07
@psanal35 psanal35 force-pushed the add-rocm-model-benchmarks branch from cd5ad81 to 34b2d0f Compare May 11, 2026 17:35
@psanal35
Author

No further action needed, closing the PR.

@psanal35 psanal35 closed this May 20, 2026
@psanal35 psanal35 deleted the add-rocm-model-benchmarks branch May 20, 2026 03:24