From 66f6a53fb5a2d7b008a9a25204111262e2ef17b6 Mon Sep 17 00:00:00 2001
From: "Yaxun (Sam) Liu"
Date: Mon, 6 Oct 2025 17:11:21 -0400
Subject: [PATCH] convert HIP struct type vector to llvm vector type

---
 llvm/lib/Transforms/Scalar/SROA.cpp | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index becda960a16f0..e6d537f4678ee 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -83,6 +83,7 @@
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/PromoteMemToReg.h"
 #include "llvm/Transforms/Utils/SSAUpdater.h"
 #include <algorithm>
@@ -5007,6 +5008,34 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
     // FIXME: We might want to defer PHI speculation until after here.
     // FIXME: return nullptr;
   } else {
+    // AMDGPU only: when the chosen SliceTy is a named struct whose name starts
+    // with "struct.HIP_vector" and which has 2 or 4 members of one identical
+    // type, canonicalize it to the matching fixed IR vector so the new alloca
+    // holds a single value and vector loads/stores become possible.
+    if (Function *F = AI.getFunction()) {
+      const Module &M = *F->getParent();
+      auto *STy = dyn_cast<StructType>(SliceTy);
+      if (Triple(M.getTargetTriple()).isAMDGPU() && STy && STy->hasName() &&
+          STy->getName().starts_with("struct.HIP_vector")) {
+        unsigned NumElts = STy->getNumElements();
+        if (NumElts == 2 || NumElts == 4) {
+          Type *EltTy = STy->getElementType(0);
+          bool AllSame = true;
+          for (unsigned I = 1; I < NumElts && AllSame; ++I)
+            AllSame = STy->getElementType(I) == EltTy;
+          if (AllSame && VectorType::isValidElementType(EltTy)) {
+            // Substitute only when the vector occupies exactly as many bytes
+            // as the struct; vectors pack sub-byte elements, so otherwise the
+            // new alloca could end up smaller than the partition it backs.
+            auto *VecTy = FixedVectorType::get(EltTy, NumElts);
+            const DataLayout &Layout = M.getDataLayout();
+            if (Layout.getTypeAllocSize(VecTy) == Layout.getTypeAllocSize(STy))
+              SliceTy = VecTy;
+          }
+        }
+      }
+    }
+
     // Make sure the alignment is compatible with P.beginOffset().
     const Align Alignment = commonAlignment(AI.getAlign(), P.beginOffset());
     // If we will get at least this much alignment from the type alone, leave