diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 9d6de81c3aef1..9e86d332c5316 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -162,6 +162,13 @@ def build_triton( cwd=triton_basedir, ) + # For gpt-oss models, triton requires this extra triton_kernels wheel + # triton_kernels came after pytorch release/2.8 + triton_kernels_dir = Path(f"{triton_basedir}/python/triton_kernels") + check_call([sys.executable, "-m", "build", "--wheel"], cwd=triton_kernels_dir, env=env) + kernels_whl_path = next(iter((triton_kernels_dir / "dist").glob("*.whl"))) + shutil.copy(kernels_whl_path, Path.cwd()) + return Path.cwd() / whl_path.name