Skip to content

[DirectX] Scalarize extractelement and insertelement with dynamic indices #141676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
67 changes: 49 additions & 18 deletions llvm/lib/Target/DirectX/DXILDataScalarization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ static const int MaxVecSize = 4;

using namespace llvm;

// Recursively creates an array-like version of a given vector type.
static Type *equivalentArrayTypeFromVector(Type *T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine; I would have just added a function declaration at the top so that the implementation could live anywhere.

if (auto *VecTy = dyn_cast<VectorType>(T))
return ArrayType::get(VecTy->getElementType(),
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
Type *NewElementType =
equivalentArrayTypeFromVector(ArrayTy->getElementType());
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
}
// If it's not a vector or array, return the original type.
return T;
}

class DXILDataScalarizationLegacy : public ModulePass {

public:
Expand Down Expand Up @@ -55,7 +69,7 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
bool visitCastInst(CastInst &CI) { return false; }
bool visitBitCastInst(BitCastInst &BCI) { return false; }
bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
bool visitExtractElementInst(ExtractElementInst &EEI);
bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
bool visitPHINode(PHINode &PHI) { return false; }
bool visitLoadInst(LoadInst &LI);
Expand Down Expand Up @@ -90,20 +104,6 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
return nullptr; // Not found
}

// Recursively builds the array-typed equivalent of the given type.
//
// - A vector type becomes an array with the same element type and the same
//   number of elements.
// - An array type is rebuilt with its element type converted recursively, so
//   arrays of vectors become arrays of arrays.
// - Any other type is returned unchanged.
//
// T is the type to convert. Ctx is not used by this function; it is kept in
// the signature so existing callers continue to compile.
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
  if (auto *VecTy = dyn_cast<VectorType>(T)) {
    // The result is dereferenced unconditionally, so use cast<> instead of
    // dyn_cast<>: a vector that is not a FixedVectorType then trips the cast
    // assertion rather than dereferencing a null pointer.
    auto *FixedVecTy = cast<FixedVectorType>(VecTy);
    return ArrayType::get(FixedVecTy->getElementType(),
                          FixedVecTy->getNumElements());
  }
  if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
    Type *NewElementType =
        replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
    return ArrayType::get(NewElementType, ArrayTy->getNumElements());
  }
  // If it's not a vector or array, return the original type.
  return T;
}

static bool isArrayOfVectors(Type *T) {
if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
return isa<VectorType>(ArrType->getElementType());
Expand All @@ -116,8 +116,7 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {

ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
IRBuilder<> Builder(&AI);
LLVMContext &Ctx = AI.getContext();
Type *NewType = replaceVectorWithArray(ArrType, Ctx);
Type *NewType = equivalentArrayTypeFromVector(ArrType);
AllocaInst *ArrAlloca =
Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
ArrAlloca->setAlignment(AI.getAlign());
Expand Down Expand Up @@ -173,6 +172,38 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
return false;
}

// Rewrites an extractelement whose index is not a constant into an indexed
// load from a stack array.
//
// The vector operand is spilled element by element (using constant-index
// extractelement instructions) into an alloca of the equivalent array type,
// and the requested element is then loaded through a GEP that uses the
// original dynamic index. The original instruction is replaced by that load
// and erased.
//
// Returns false (and changes nothing) when the index is a ConstantInt;
// returns true after the instruction has been replaced.
bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
  // If the index is a constant then we don't need to scalarize it
  Value *Index = EEI.getIndexOperand();
  if (isa<ConstantInt>(Index))
    return false;

  Type *IndexTy = Index->getType();
  VectorType *VecTy = EEI.getVectorOperandType();
  assert(VecTy->getElementCount().isFixed() &&
         "Vector operand of ExtractElement must have a fixed size");
  Type *ArrTy = equivalentArrayTypeFromVector(VecTy);

  // Create the scratch array in the function's entry block, after any allocas
  // already there. Emitting the alloca at the extractelement itself would make
  // it a dynamic stack allocation that is re-executed every time control
  // reaches this point (e.g. inside a loop).
  IRBuilder<> Builder(&EEI);
  Builder.SetInsertPointPastAllocas(EEI.getFunction());
  Value *ArrAlloca = Builder.CreateAlloca(ArrTy);

  // Everything else is emitted immediately before the original instruction.
  Builder.SetInsertPoint(&EEI);

  Value *Vec = EEI.getVectorOperand();
  Constant *Zero = ConstantInt::get(IndexTy, 0);
  const uint64_t NumElts = ArrTy->getArrayNumElements();

  // Copy every vector element into the matching array slot.
  for (uint64_t I = 0; I < NumElts; ++I) {
    Value *EE = Builder.CreateExtractElement(Vec, I);
    Value *GEP = Builder.CreateInBoundsGEP(
        ArrTy, ArrAlloca, {Zero, ConstantInt::get(IndexTy, I)});
    Builder.CreateStore(EE, GEP);
  }

  // Read back the element selected by the dynamic index.
  Value *GEP = Builder.CreateInBoundsGEP(ArrTy, ArrAlloca, {Zero, Index});
  Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), GEP);

  EEI.replaceAllUsesWith(Load);
  EEI.eraseFromParent();
  return true;
}

bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {

unsigned NumOperands = GEPI.getNumOperands();
Expand Down Expand Up @@ -257,7 +288,7 @@ static bool findAndReplaceVectors(Module &M) {
for (GlobalVariable &G : M.globals()) {
Type *OrigType = G.getValueType();

Type *NewType = replaceVectorWithArray(OrigType, Ctx);
Type *NewType = equivalentArrayTypeFromVector(OrigType);
if (OrigType != NewType) {
// Create a new global variable with the updated type
// Note: Initializer is set via transformInitializer
Expand Down
38 changes: 38 additions & 0 deletions llvm/test/CodeGen/DirectX/scalarize-dynamic-vector-index.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s

; An extractelement with a dynamic index should be rewritten: the vector is
; spilled element by element into an alloca'd array, and the requested element
; is loaded back through a GEP that uses the dynamic index.
define float @extract_float_vec_dynamic(<4 x float> %0, i32 %1) {
; CHECK-LABEL: define float @extract_float_vec_dynamic(
; CHECK-SAME: <4 x float> [[TMP0:%.*]], i32 [[TMP1:%.*]]) {
; CHECK-NEXT: [[TMP3:%.*]] = alloca [4 x float], align 4
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[TMP0]], i64 0
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP3]], i32 0, i32 0
; CHECK-NEXT: store float [[TMP4]], ptr [[TMP5]], align 4
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x float> [[TMP0]], i64 1
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP3]], i32 0, i32 1
; CHECK-NEXT: store float [[TMP6]], ptr [[TMP7]], align 4
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <4 x float> [[TMP0]], i64 2
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP3]], i32 0, i32 2
; CHECK-NEXT: store float [[TMP8]], ptr [[TMP9]], align 4
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <4 x float> [[TMP0]], i64 3
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP3]], i32 0, i32 3
; CHECK-NEXT: store float [[TMP10]], ptr [[TMP11]], align 4
; CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP3]], i32 0, i32 [[TMP1]]
; CHECK-NEXT: [[TMP13:%.*]] = load float, ptr [[TMP12]], align 4
; CHECK-NEXT: ret float [[TMP13]]
;
%e = extractelement <4 x float> %0, i32 %1
ret float %e
}

; An extractelement with a constant index should not be converted to array
; form: the instruction is expected to pass through the pass unchanged, with no
; alloca, store, or load introduced.
define i16 @extract_i16_vec_constant(<4 x i16> %0) {
; CHECK-LABEL: define i16 @extract_i16_vec_constant(
; CHECK-SAME: <4 x i16> [[TMP0:%.*]]) {
; CHECK-NEXT: [[E:%.*]] = extractelement <4 x i16> [[TMP0]], i32 1
; CHECK-NEXT: ret i16 [[E]]
;
%e = extractelement <4 x i16> %0, i32 1
ret i16 %e
}

Loading