Skip to content

[Matrix] Propagate shape information through cast insts #141869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
Expand Down Expand Up @@ -247,6 +249,34 @@ static bool isUniformShape(Value *V) {
if (I->isBinaryOp())
return true;

if (auto *Cast = dyn_cast<CastInst>(V)) {
switch (Cast->getOpcode()) {
case llvm::Instruction::Trunc:
case llvm::Instruction::ZExt:
case llvm::Instruction::SExt:
case llvm::Instruction::FPToUI:
case llvm::Instruction::FPToSI:
case llvm::Instruction::UIToFP:
case llvm::Instruction::SIToFP:
case llvm::Instruction::FPTrunc:
case llvm::Instruction::FPExt:
return true;
case llvm::Instruction::AddrSpaceCast:
case CastInst::PtrToInt:
case CastInst::IntToPtr:
return false;
case CastInst::BitCast: {
if (auto *SrcVTy = dyn_cast<FixedVectorType>(Cast->getSrcTy()))
if (auto *DestVTy = dyn_cast<FixedVectorType>(Cast->getDestTy()))
return SrcVTy->getNumElements() == DestVTy->getNumElements();
return false;
}
case llvm::Instruction::CastOpsEnd:
llvm_unreachable("not an actual cast op");
}
llvm_unreachable("unhandled cast opcode");
}

if (auto *II = dyn_cast<IntrinsicInst>(V))
switch (II->getIntrinsicID()) {
case Intrinsic::abs:
Expand Down Expand Up @@ -1110,6 +1140,8 @@ class LowerMatrixIntrinsics {
Value *Op2;
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
VisitBinaryOperator(BinOp, SI);
else if (auto *Cast = dyn_cast<CastInst>(Inst))
VisitCastInstruction(Cast, SI);
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
VisitUnaryOperator(UnOp, SI);
else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst))
Expand Down Expand Up @@ -2260,6 +2292,30 @@ class LowerMatrixIntrinsics {
Builder);
}

/// Lower a cast instruction whose operand carries matrix shape information,
/// by emitting one cast per column (or row) vector of the operand.
void VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) {
IRBuilder<> Builder(Inst);

// Split the flattened source operand into its shape-conforming vectors.
MatrixTy SrcMatrix = getMatrix(Inst->getOperand(0), Shape, Builder);

// Propagate any fast-math flags of the original cast (e.g. on fptrunc)
// onto the per-vector casts created below.
Builder.setFastMathFlags(getFastMathFlags(Inst));

// Each split cast keeps the destination element type, but only spans a
// single column/row (getStride() elements) of the matrix.
auto *FlatDestTy = cast<VectorType>(Inst->getType());
auto *SplitDestTy = VectorType::get(FlatDestTy->getElementType(),
ElementCount::getFixed(SrcMatrix.getStride()));

MatrixTy Lowered;
for (Value *Vec : SrcMatrix.vectors())
Lowered.addVector(Builder.CreateCast(Inst->getOpcode(), Vec, SplitDestTy));

// Account one op per produced vector and replace the original cast.
unsigned OpsPerVector = getNumOps(Lowered.getVectorTy());
finalizeLowering(
Inst, Lowered.addNumComputeOps(OpsPerVector * Lowered.getNumVectors()),
Builder);
}

/// Helper to linearize a matrix expression tree into a string. Currently
matrix expressions are linearized by starting at an expression leaf and
/// linearizing bottom up.
Expand Down
250 changes: 250 additions & 0 deletions llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s

; Baseline (pre-existing unary-op path): fneg on a 2x2 matrix value is
; lowered to one fneg per 2-element column vector.
define void @fneg_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @fneg_2x2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT: [[TMP1:%.*]] = fneg <2 x float> [[COL_LOAD]]
; CHECK-NEXT: [[TMP2:%.*]] = fneg <2 x float> [[COL_LOAD1]]
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 4
; CHECK-NEXT: ret void
;
%inv = load <4 x float>, ptr %in
%op = fneg <4 x float> %inv
call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Shape info propagates through trunc: the i64->i32 cast is split per column.
define void @trunc_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @trunc_2x2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i64> [[COL_LOAD]] to <2 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = trunc <2 x i64> [[COL_LOAD1]] to <2 x i32>
; CHECK-NEXT: store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
; CHECK-NEXT: ret void
;
%inv = load <4 x i64>, ptr %in
%op = trunc <4 x i64> %inv to <4 x i32>
call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Shape info propagates through zext: the i16->i32 cast is split per column.
define void @zext_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @zext_2x2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i16>, ptr [[IN:%.*]], align 8
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i16, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i16>, ptr [[VEC_GEP]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = zext <2 x i16> [[COL_LOAD]] to <2 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = zext <2 x i16> [[COL_LOAD1]] to <2 x i32>
; CHECK-NEXT: store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
; CHECK-NEXT: ret void
;
%inv = load <4 x i16>, ptr %in
%op = zext <4 x i16> %inv to <4 x i32>
call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Shape info propagates through sext: the i8->i16 cast is split per column.
define void @sext_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @sext_2x2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i8>, ptr [[IN:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i8, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i8>, ptr [[VEC_GEP]], align 2
; CHECK-NEXT: [[TMP1:%.*]] = sext <2 x i8> [[COL_LOAD]] to <2 x i16>
; CHECK-NEXT: [[TMP2:%.*]] = sext <2 x i8> [[COL_LOAD1]] to <2 x i16>
; CHECK-NEXT: store <2 x i16> [[TMP1]], ptr [[OUT:%.*]], align 2
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i16, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x i16> [[TMP2]], ptr [[VEC_GEP2]], align 2
; CHECK-NEXT: ret void
;
%inv = load <4 x i8>, ptr %in
%op = sext <4 x i8> %inv to <4 x i16>
call void @llvm.matrix.column.major.store(<4 x i16> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Shape info propagates through fptoui: float->i32 cast is split per column.
define void @fptoui_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @fptoui_2x2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT: [[TMP1:%.*]] = fptoui <2 x float> [[COL_LOAD]] to <2 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = fptoui <2 x float> [[COL_LOAD1]] to <2 x i32>
; CHECK-NEXT: store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
; CHECK-NEXT: ret void
;
%inv = load <4 x float>, ptr %in
%op = fptoui <4 x float> %inv to <4 x i32>
call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Shape info propagates through fptosi: float->i32 cast is split per column.
define void @fptosi_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @fptosi_2x2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT: [[TMP1:%.*]] = fptosi <2 x float> [[COL_LOAD]] to <2 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = fptosi <2 x float> [[COL_LOAD1]] to <2 x i32>
; CHECK-NEXT: store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
; CHECK-NEXT: ret void
;
%inv = load <4 x float>, ptr %in
%op = fptosi <4 x float> %inv to <4 x i32>
call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Shape info propagates through uitofp: i64->double cast is split per column.
define void @uitofp_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @uitofp_2x2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
; CHECK-NEXT: [[TMP1:%.*]] = uitofp <2 x i64> [[COL_LOAD]] to <2 x double>
; CHECK-NEXT: [[TMP2:%.*]] = uitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
; CHECK-NEXT: store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 8
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 8
; CHECK-NEXT: ret void
;
%inv = load <4 x i64>, ptr %in
%op = uitofp <4 x i64> %inv to <4 x double>
call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Shape info propagates through sitofp: i64->double cast is split per column.
define void @sitofp_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @sitofp_2x2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
; CHECK-NEXT: [[TMP1:%.*]] = sitofp <2 x i64> [[COL_LOAD]] to <2 x double>
; CHECK-NEXT: [[TMP2:%.*]] = sitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
; CHECK-NEXT: store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 8
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 8
; CHECK-NEXT: ret void
;
%inv = load <4 x i64>, ptr %in
%op = sitofp <4 x i64> %inv to <4 x double>
call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Shape info propagates through fptrunc, and the fast-math flag (nnan) is
; preserved on each of the split per-column casts.
define void @fptrunc_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @fptrunc_2x2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
; CHECK-NEXT: [[TMP1:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD]] to <2 x float>
; CHECK-NEXT: [[TMP2:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD1]] to <2 x float>
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 4
; CHECK-NEXT: ret void
;
%inv = load <4 x double>, ptr %in
%op = fptrunc nnan <4 x double> %inv to <4 x float>
call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Shape info propagates through fpext: float->double cast is split per column.
define void @fpext_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @fpext_2x2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT: [[TMP1:%.*]] = fpext <2 x float> [[COL_LOAD]] to <2 x double>
; CHECK-NEXT: [[TMP2:%.*]] = fpext <2 x float> [[COL_LOAD1]] to <2 x double>
; CHECK-NEXT: store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 8
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 8
; CHECK-NEXT: ret void
;
%inv = load <4 x float>, ptr %in
%op = fpext <4 x float> %inv to <4 x double>
call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Element-count-preserving vector bitcast (4 x double -> 4 x i64) keeps the
; matrix shape, so the cast is split per column.
define void @bitcast_2x2_v4f64_to_v4i64(ptr %in, ptr %out) {
; CHECK-LABEL: @bitcast_2x2_v4f64_to_v4i64(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x double> [[COL_LOAD]] to <2 x i64>
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x double> [[COL_LOAD1]] to <2 x i64>
; CHECK-NEXT: store <2 x i64> [[TMP1]], ptr [[OUT:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i64, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x i64> [[TMP2]], ptr [[VEC_GEP2]], align 4
; CHECK-NEXT: ret void
;
%inv = load <4 x double>, ptr %in
%op = bitcast <4 x double> %inv to <4 x i64>
call void @llvm.matrix.column.major.store(<4 x i64> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Negative test: a bitcast that changes the element count (4 -> 8) is not
; shape-uniform, so it is left flattened and the result is split afterwards.
define void @bitcast_2x2_v4f64_to_v8i32(ptr %in, ptr %out) {
; CHECK-LABEL: @bitcast_2x2_v4f64_to_v8i32(
; CHECK-NEXT: [[INV:%.*]] = load <4 x double>, ptr [[IN:%.*]], align 32
; CHECK-NEXT: [[OP:%.*]] = bitcast <4 x double> [[INV]] to <8 x i32>
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[OP]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[OP]], <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: store <4 x i32> [[SPLIT]], ptr [[OUT:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 4
; CHECK-NEXT: store <4 x i32> [[SPLIT1]], ptr [[VEC_GEP]], align 4
; CHECK-NEXT: ret void
;
%inv = load <4 x double>, ptr %in
%op = bitcast <4 x double> %inv to <8 x i32>
call void @llvm.matrix.column.major.store(<8 x i32> %op, ptr %out, i64 4, i1 false, i32 4, i32 2)
ret void
}

; Negative test: a scalar-to-vector bitcast is not shape-uniform, so the cast
; stays flattened and only its result is split for the column-major store.
; NOTE(review): the name says v4i64 but the body bitcasts to <4 x double> —
; consider renaming (would also need the CHECK-LABEL updated).
define void @bitcast_2x2_i256_to_v4i64(ptr %in, ptr %out) {
; CHECK-LABEL: @bitcast_2x2_i256_to_v4i64(
; CHECK-NEXT: [[INV:%.*]] = load i256, ptr [[IN:%.*]], align 4
; CHECK-NEXT: [[OP:%.*]] = bitcast i256 [[INV]] to <4 x double>
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x double> [[OP]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x double> [[OP]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
; CHECK-NEXT: store <2 x double> [[SPLIT]], ptr [[OUT:%.*]], align 8
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[OUT]], i64 2
; CHECK-NEXT: store <2 x double> [[SPLIT1]], ptr [[VEC_GEP]], align 8
; CHECK-NEXT: ret void
;
%inv = load i256, ptr %in
%op = bitcast i256 %inv to <4 x double>
call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
ret void
}

; Negative test: a vector-to-scalar bitcast is not shape-uniform; the matrix
; operand's columns are re-flattened (shufflevector) before the bitcast.
; NOTE(review): the name says 4i64 but the body bitcasts <4 x double> —
; consider renaming (would also need the CHECK-LABEL updated).
define void @bitcast_2x2_4i64_to_i256(ptr %in, ptr %out) {
; CHECK-LABEL: @bitcast_2x2_4i64_to_i256(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 8
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[OP:%.*]] = bitcast <4 x double> [[TMP1]] to i256
; CHECK-NEXT: store i256 [[OP]], ptr [[OUT:%.*]], align 4
; CHECK-NEXT: ret void
;
%inv = call <4 x double> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 false, i32 2, i32 2)
%op = bitcast <4 x double> %inv to i256
store i256 %op, ptr %out
ret void
}
Loading