Skip to content

Commit

Permalink
PR #10811: Restore ReshapeDecomposer And LayoutNormalization Passes i…
Browse files Browse the repository at this point in the history
…n GPU Compiler

Imported from GitHub PR #10811

Restores individual ReshapeDecomposer and LayoutNormalization passes in the GPU compiler previously removed in #9852 and fixes failures in cudnn_norm_rewriter_test.cc.
Copybara import of the project:

--
5206571 by Philipp Hack <phack@nvidia.com>:

Restore ReshapeDecomposer and LayoutNormalization passes.

Merging this change closes #10811

COPYBARA_INTEGRATE_REVIEW=#10811 from philipphack:u_layer_reshape_decomposer_xla 5206571
PiperOrigin-RevId: 618059598
  • Loading branch information
philipphack authored and copybara-github committed Mar 22, 2024
1 parent db867c2 commit c408b63
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
pipeline.AddPass<TransposeFolding>(CanFoldTransposeOperandIntoDot,
TransposeFolding::NeverFoldTranspose);

pipeline.AddPass<ReshapeDecomposer>();
pipeline.AddPass<ReduceDecomposer>([&](const HloInstruction* r) {
return IsReductionFromOrToContiguousDimensions(*r);
});
Expand Down Expand Up @@ -1418,6 +1419,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
// Rewrite GEMMs with broadcasted inputs as strided GEMMs.
pipeline.AddPass<GemmBroadcastFoldingRewriter>();

if (debug_options.xla_gpu_normalize_layouts()) {
pipeline.AddPass<LayoutNormalization>(&NormalizeLayoutForGpuCustomCalls);
}
pipeline.AddPass<BroadcastCanonicalizer>();

pipeline.AddPass<ReductionDegenerateDimRemover>();
Expand Down

0 comments on commit c408b63

Please sign in to comment.