Skip to content

Added Dockerfile for CI images & Upgrate CI to ROCm 7.2#195

Merged
VeeraRajasekhar merged 11 commits intodevfrom
dockerfile
Feb 24, 2026
Merged

Added Dockerfile for CI images & Upgrate CI to ROCm 7.2#195
VeeraRajasekhar merged 11 commits intodevfrom
dockerfile

Conversation

@VeeraRajasekhar
Copy link
Copy Markdown
Contributor

Description

Added the dockerfile, which can be used to create the ci-artifactory images.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added a new file docker/Dockerfile

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Comment thread docker/Dockerfile Outdated
Comment thread docker/Dockerfile Outdated
Comment thread docker/Dockerfile Outdated
Copy link
Copy Markdown
Collaborator

@wenchenvincent wenchenvincent left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please address the comments.

Comment thread docker/Dockerfile Outdated
Comment thread docker/Dockerfile Outdated
Comment thread docker/Dockerfile Outdated
Comment thread docker/Dockerfile Outdated
Comment thread docker/Dockerfile Outdated
Comment thread docker/Dockerfile Outdated
Comment thread docker/Dockerfile Outdated
Comment thread docker/Dockerfile Outdated
Comment thread docker/Dockerfile Outdated
Copy link
Copy Markdown
Collaborator

@ipanfilo ipanfilo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why conversations are marked as resolved w/o any actual action?

@VeeraRajasekhar
Copy link
Copy Markdown
Contributor Author

Why conversations are marked as resolved w/o any actual action?

Some of them, I have resolved, some I have currently resolved in my local, just to keep track I will mark them resolved.

Comment thread docker/Dockerfile.ci.deps Outdated
Comment thread docker/Dockerfile.ci.deps Outdated
Comment thread docker/Dockerfile.ci.deps Outdated
@wenchenvincent
Copy link
Copy Markdown
Collaborator

@VeeraRajasekhar Is this PR still needed?

@wenchenvincent
Copy link
Copy Markdown
Collaborator

@VeeraRajasekhar Could you remind me of what we had decided on this PR? It seemed that it is no longer relevant and we should close it.

@VeeraRajasekhar
Copy link
Copy Markdown
Contributor Author

Hi @ipanfilo, @wangye805

I have updated this PR with latest 7.2 docker file and moved to .github/scripts.

Let me know if I need to add an action to automate docker build and upload to our artifactory?

Thanks.

Comment thread .github/scripts/Dockerfile.ci.deps Outdated
Comment thread .github/scripts/Dockerfile.ci.deps Outdated
Comment thread .github/scripts/Dockerfile.ci.deps Outdated
Comment thread .github/scripts/Dockerfile.ci.deps Outdated
@VeeraRajasekhar
Copy link
Copy Markdown
Contributor Author

I had to force push to include new FA 2.8.3 support commit and my changes for 7.2 support to run the CI.

Thanks.

@VeeraRajasekhar
Copy link
Copy Markdown
Contributor Author

@Micky774, please review the following,

Analysis on testing on Jax & xla 0.8.2

(Not Supported) jax.nn.scaled_matmul (MXFP8) on ROCm crashes with a segmentation fault when the contracting dimension (K) is less than 64.

import functools
import jax
import jax.numpy as jnp
from jax import nn

key = jax.random.PRNGKey(0)
key_a, key_b = jax.random.split(key)
B, M, N, K = 1, 128, 128, 32


lhs = jax.random.normal(key_a, (B, M, K), dtype=jnp.float32)
rhs = jax.random.normal(key_b, (B, N, K), dtype=jnp.float32)

# 1. high-precision matmul
ref = jnp.einsum("bmk,bnk->bmn", lhs, rhs)

# 2. mxfp8 matmul
configs = [nn.get_scaled_dot_general_config("mxfp8")] * 3
scaled_dot = functools.partial(
    nn.scaled_dot_general,
    configs=configs,
    preferred_element_type=jnp.float32,
)

out = scaled_dot(lhs, rhs, (((2,), (2,)), ((0,), (0,))))

# compare results
print("high-precision ref: ")
print(ref)

print("mxfp8 out: ")
print(out)

max_abs = jnp.max(jnp.abs(out - ref))
max_rel = max_abs / jnp.max(jnp.abs(ref))

print("max abs error:", max_abs)
print("max rel error:", max_rel)

[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x32x64_UR_2: K must be a multiple of workgroupTile.k=64 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_16x16x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x16x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_16x32x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x32x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x16x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_16x64x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x32x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x64x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_16x16x256_UR_2: K must be a multiple of workgroupTile.k=256 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x64x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_128x32x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x128x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_128x64x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x128x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_16x64x256_UR_2: K must be a multiple of workgroupTile.k=256 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_192x32x128_UR_2: M must be a multiple of workgroupTile.m=192
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_192x32x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_128x128x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x192x128_UR_2: N must be a multiple of workgroupTile.n=192
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x192x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_192x64x128_UR_2: M must be a multiple of workgroupTile.m=192
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_192x64x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x192x128_UR_2: N must be a multiple of workgroupTile.n=192
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x192x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
Segmentation fault (core dumped)


@VeeraRajasekhar
Copy link
Copy Markdown
Contributor Author

@Micky774
Copy link
Copy Markdown
Contributor

Micky774 commented Feb 16, 2026

@Micky774, please review the following,

Analysis on testing on Jax & xla 0.8.2

(Not Supported) jax.nn.scaled_matmul (MXFP8) on ROCm crashes with a segmentation fault when the contracting dimension (K) is less than 64.
...

This is a failure on certain configs for hipblaslt, which we already have tickets open for. We don't support these configs in TE anyways, so it's a known issue and not a blocker. Shouldn't be a problem for this PR.

@VeeraRajasekhar VeeraRajasekhar changed the title Added Dockerfile for CI images Added Dockerfile for CI images & Upgrate CI to ROCm 7.2 Feb 16, 2026
Comment thread tests/jax/utils.py
Comment thread tests/jax/utils.py
Comment thread .github/scripts/Dockerfile.ci.deps Outdated
Comment thread .github/scripts/Dockerfile.ci.deps Outdated
Comment thread .github/scripts/Dockerfile.ci.deps
@Micky774
Copy link
Copy Markdown
Contributor

@VeeraRajasekhar you'll need to merge w/ dev to fix CI

@VeeraRajasekhar
Copy link
Copy Markdown
Contributor Author

ROCm's jax.nn.scaled_matmul kernels require the contracting dimension (K)
to be at least 64. Without this validation, backward pass GEMMs with K < 64
cause segmentation faults.

Added K >= 64 check in _check_mxfp8_gemm_support() for JAX GEMM on ROCm.

Fixes: test_dense_grad_fp8[MXFP8_1D_SCALING-with_jax_gemm_True-64-32-64]
@VeeraRajasekhar VeeraRajasekhar merged commit b685686 into dev Feb 24, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants