Skip to content

[LoongArch] Lower vector select mask generation to [X]VMSK{LT,GE,NE}Z if possible #142109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 230 additions & 1 deletion llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,

// Set DAG combine for 'LSX' feature.

if (Subtarget.hasExtLSX())
if (Subtarget.hasExtLSX()) {
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
setTargetDAGCombine(ISD::BITCAST);
}

// Compute derived properties from the register classes.
computeRegisterProperties(Subtarget.getRegisterInfo());
Expand Down Expand Up @@ -4286,6 +4288,94 @@ static SDValue performSRLCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

/// Try to fold `bitcast (setcc X, C, CC)` - where the setcc yields a vector of
/// i1 - into a single [X]VMSK{LTZ,GEZ,EQZ,NEZ} target node applied to X.
///
/// The target node is created with an i64 result; that value is then
/// zero-extended or truncated to an integer with one bit per vector element
/// and finally bitcast to the type of the original node \p N.
///
/// \param N          The ISD::BITCAST node being combined.
/// \param DAG        The DAG used to build replacement nodes.
/// \param DCI        Combiner state; the fold only runs before op legalization.
/// \param Subtarget  Queried for LSX / LASX / 32S availability.
/// \returns The replacement value, or an empty SDValue when no fold applies.
static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     const LoongArchSubtarget &Subtarget) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  SDValue Src = N->getOperand(0);
  EVT SrcVT = Src.getValueType();

  if (!DCI.isBeforeLegalizeOps())
    return SDValue();

  // Only a *vector* of i1 is interesting. A scalar i1 also has scalar type i1,
  // but the vector-only queries below (getVectorElementType /
  // getVectorNumElements) must not be called on it.
  if (!SrcVT.isSimple() || !SrcVT.isVector() ||
      SrcVT.getScalarType() != MVT::i1)
    return SDValue();

  // The compare is replaced, so it must have no other users.
  if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
    return SDValue();

  // Pick the 128-bit (LSX) or 256-bit (LASX) flavour from the width of the
  // compared vector.
  // NOTE(review): vectors narrower than 128 bits also take the LSX path here;
  // confirm that such operand types are handled downstream.
  bool UseLASX;
  EVT CmpVT = Src.getOperand(0).getValueType();
  EVT EltVT = CmpVT.getVectorElementType();
  if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
    UseLASX = false;
  else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
           CmpVT.getSizeInBits() <= 256)
    UseLASX = true;
  else
    return SDValue();

  SDValue SrcN1 = Src.getOperand(1);
  const bool RHSIsZero = ISD::isBuildVectorAllZeros(SrcN1.getNode());
  const bool RHSIsAllOnes = ISD::isBuildVectorAllOnes(SrcN1.getNode());
  const bool IsByteElt = EltVT == MVT::i8;
  const bool IsIntElt = EltVT == MVT::i8 || EltVT == MVT::i16 ||
                        EltVT == MVT::i32 || EltVT == MVT::i64;

  // DELETED_NODE doubles as the "no fold found" marker.
  unsigned Opc = ISD::DELETED_NODE;
  switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
  default:
    return SDValue();
  case ISD::SETEQ:
    // x == 0 => not (vmsknez.b x)
    // Note: `x == -1` is deliberately NOT folded. It holds only when every
    // bit of the byte is set, whereas vmsknez reports any non-zero byte, so
    // the two are not equivalent (e.g. x = 1).
    if (RHSIsZero && IsByteElt)
      Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
    break;
  case ISD::SETGT:
    // x > -1 => vmskgez.b x
    if (RHSIsAllOnes && IsByteElt)
      Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
    break;
  case ISD::SETGE:
    // x >= 0 => vmskgez.b x
    if (RHSIsZero && IsByteElt)
      Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
    break;
  case ISD::SETLT:
    // x < 0 => vmskltz.{b,h,w,d} x
    if (RHSIsZero && IsIntElt)
      Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
    break;
  case ISD::SETLE:
    // x <= -1 => vmskltz.{b,h,w,d} x
    if (RHSIsAllOnes && IsIntElt)
      Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
    break;
  case ISD::SETNE:
    // x != 0 => vmsknez.b x
    // Note: `x != -1` is deliberately NOT folded to vmskeqz for the same
    // reason as above: a byte can differ from -1 without being zero.
    if (RHSIsZero && IsByteElt)
      Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
    break;
  }

  if (Opc == ISD::DELETED_NODE)
    return SDValue();

  // Build the mask node, resize it to one bit per element, then reinterpret
  // it as the type the original bitcast produced.
  SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
  EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
  V = DAG.getZExtOrTrunc(V, DL, T);
  return DAG.getBitcast(VT, V);
}

static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const LoongArchSubtarget &Subtarget) {
Expand Down Expand Up @@ -5303,6 +5393,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
return performSETCCCombine(N, DAG, DCI, Subtarget);
case ISD::SRL:
return performSRLCombine(N, DAG, DCI, Subtarget);
case ISD::BITCAST:
return performBITCASTCombine(N, DAG, DCI, Subtarget);
case LoongArchISD::BITREV_W:
return performBITREV_WCombine(N, DAG, DCI, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
Expand Down Expand Up @@ -5589,6 +5681,120 @@ static MachineBasicBlock *emitPseudoCTPOP(MachineInstr &MI,
return BB;
}

/// Custom inserter for the Pseudo[X]VMSK{LTZ,GEZ,EQZ,NEZ}_* pseudos.
///
/// Each pseudo takes a vector register and defines a GPR. It is expanded to:
///   1. the matching [X]VMSK* instruction writing a vector temporary,
///   2. for the EQZ variants, a [X]VNOR.V of that temporary with itself
///      (bitwise NOT of the "not zero" mask),
///   3. a move of the mask bits into the destination GPR: a single
///      VPICKVE2GR.HU of element 0 for 128-bit registers, or two
///      XVPICKVE2GR.WU picks (elements 0 and 4) merged with BSTRINS for
///      256-bit registers.
///
/// \param MI         The pseudo instruction; erased before returning.
/// \param BB         The block containing \p MI; new code is inserted before it.
/// \param Subtarget  Supplies instruction/register info and the GPR width.
/// \returns \p BB (no new blocks are created).
static MachineBasicBlock *
emitPseudoVMSKCOND(MachineInstr &MI, MachineBasicBlock *BB,
                   const LoongArchSubtarget &Subtarget) {
  const TargetInstrInfo *TII = Subtarget.getInstrInfo();
  const TargetRegisterClass *RC = &LoongArch::LSX128RegClass;
  const LoongArchRegisterInfo *TRI = Subtarget.getRegisterInfo();
  MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  DebugLoc DL = MI.getDebugLoc();
  unsigned EleBits = 8; // Element width in bits; defaults to the .b forms.
  unsigned NotOpc = 0;  // Non-zero when the mask must be inverted (EQZ).
  unsigned MskOpc;

  switch (MI.getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case LoongArch::PseudoVMSKLTZ_B:
    MskOpc = LoongArch::VMSKLTZ_B;
    break;
  case LoongArch::PseudoVMSKLTZ_H:
    MskOpc = LoongArch::VMSKLTZ_H;
    EleBits = 16;
    break;
  case LoongArch::PseudoVMSKLTZ_W:
    MskOpc = LoongArch::VMSKLTZ_W;
    EleBits = 32;
    break;
  case LoongArch::PseudoVMSKLTZ_D:
    MskOpc = LoongArch::VMSKLTZ_D;
    EleBits = 64;
    break;
  case LoongArch::PseudoVMSKGEZ_B:
    MskOpc = LoongArch::VMSKGEZ_B;
    break;
  case LoongArch::PseudoVMSKEQZ_B:
    // "equal to zero" is produced as NOT("not zero").
    MskOpc = LoongArch::VMSKNZ_B;
    NotOpc = LoongArch::VNOR_V;
    break;
  case LoongArch::PseudoVMSKNEZ_B:
    MskOpc = LoongArch::VMSKNZ_B;
    break;
  case LoongArch::PseudoXVMSKLTZ_B:
    MskOpc = LoongArch::XVMSKLTZ_B;
    RC = &LoongArch::LASX256RegClass;
    break;
  case LoongArch::PseudoXVMSKLTZ_H:
    MskOpc = LoongArch::XVMSKLTZ_H;
    RC = &LoongArch::LASX256RegClass;
    EleBits = 16;
    break;
  case LoongArch::PseudoXVMSKLTZ_W:
    MskOpc = LoongArch::XVMSKLTZ_W;
    RC = &LoongArch::LASX256RegClass;
    EleBits = 32;
    break;
  case LoongArch::PseudoXVMSKLTZ_D:
    MskOpc = LoongArch::XVMSKLTZ_D;
    RC = &LoongArch::LASX256RegClass;
    EleBits = 64;
    break;
  case LoongArch::PseudoXVMSKGEZ_B:
    MskOpc = LoongArch::XVMSKGEZ_B;
    RC = &LoongArch::LASX256RegClass;
    break;
  case LoongArch::PseudoXVMSKEQZ_B:
    // "equal to zero" is produced as NOT("not zero").
    MskOpc = LoongArch::XVMSKNZ_B;
    NotOpc = LoongArch::XVNOR_V;
    RC = &LoongArch::LASX256RegClass;
    break;
  case LoongArch::PseudoXVMSKNEZ_B:
    MskOpc = LoongArch::XVMSKNZ_B;
    RC = &LoongArch::LASX256RegClass;
    break;
  }

  // Compute the mask (optionally inverted) into the vector register Msk.
  Register Msk = MRI.createVirtualRegister(RC);
  if (NotOpc) {
    Register Tmp = MRI.createVirtualRegister(RC);
    BuildMI(*BB, MI, DL, TII->get(MskOpc), Tmp).addReg(Src);
    // NOR of a value with itself is its bitwise complement.
    BuildMI(*BB, MI, DL, TII->get(NotOpc), Msk)
        .addReg(Tmp, RegState::Kill)
        .addReg(Tmp, RegState::Kill);
  } else {
    BuildMI(*BB, MI, DL, TII->get(MskOpc), Msk).addReg(Src);
  }

  if (TRI->getRegSizeInBits(*RC) > 128) {
    // 256-bit case: pick word 0 and word 4 of Msk, then insert the bits taken
    // from Hi into Lo directly above the 128 / EleBits bits kept from Lo.
    Register Lo = MRI.createVirtualRegister(&LoongArch::GPRRegClass);
    Register Hi = MRI.createVirtualRegister(&LoongArch::GPRRegClass);
    // Msk is read again by the next instruction, so it must not carry a kill
    // flag here; only its last use below is marked as killing it.
    BuildMI(*BB, MI, DL, TII->get(LoongArch::XVPICKVE2GR_WU), Lo)
        .addReg(Msk)
        .addImm(0);
    BuildMI(*BB, MI, DL, TII->get(LoongArch::XVPICKVE2GR_WU), Hi)
        .addReg(Msk, RegState::Kill)
        .addImm(4);
    BuildMI(*BB, MI, DL,
            TII->get(Subtarget.is64Bit() ? LoongArch::BSTRINS_D
                                         : LoongArch::BSTRINS_W),
            Dst)
        .addReg(Lo, RegState::Kill)
        .addReg(Hi, RegState::Kill)
        .addImm(256 / EleBits - 1)
        .addImm(128 / EleBits);
  } else {
    // 128-bit case: halfword element 0 of Msk is copied to the GPR.
    BuildMI(*BB, MI, DL, TII->get(LoongArch::VPICKVE2GR_HU), Dst)
        .addReg(Msk, RegState::Kill)
        .addImm(0);
  }

  MI.eraseFromParent();
  return BB;
}

static bool isSelectPseudo(MachineInstr &MI) {
switch (MI.getOpcode()) {
default:
Expand Down Expand Up @@ -5795,6 +6001,21 @@ MachineBasicBlock *LoongArchTargetLowering::EmitInstrWithCustomInserter(
return emitPseudoXVINSGR2VR(MI, BB, Subtarget);
case LoongArch::PseudoCTPOP:
return emitPseudoCTPOP(MI, BB, Subtarget);
case LoongArch::PseudoVMSKLTZ_B:
case LoongArch::PseudoVMSKLTZ_H:
case LoongArch::PseudoVMSKLTZ_W:
case LoongArch::PseudoVMSKLTZ_D:
case LoongArch::PseudoVMSKGEZ_B:
case LoongArch::PseudoVMSKEQZ_B:
case LoongArch::PseudoVMSKNEZ_B:
case LoongArch::PseudoXVMSKLTZ_B:
case LoongArch::PseudoXVMSKLTZ_H:
case LoongArch::PseudoXVMSKLTZ_W:
case LoongArch::PseudoXVMSKLTZ_D:
case LoongArch::PseudoXVMSKGEZ_B:
case LoongArch::PseudoXVMSKEQZ_B:
case LoongArch::PseudoXVMSKNEZ_B:
return emitPseudoVMSKCOND(MI, BB, Subtarget);
case TargetOpcode::STATEPOINT:
// STATEPOINT is a pseudo instruction which has no implicit defs/uses
// while bl call instruction (where statepoint will be lowered at the
Expand Down Expand Up @@ -5916,6 +6137,14 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VBSLL)
NODE_NAME_CASE(VBSRL)
NODE_NAME_CASE(VLDREPL)
NODE_NAME_CASE(VMSKLTZ)
NODE_NAME_CASE(VMSKGEZ)
NODE_NAME_CASE(VMSKEQZ)
NODE_NAME_CASE(VMSKNEZ)
NODE_NAME_CASE(XVMSKLTZ)
NODE_NAME_CASE(XVMSKGEZ)
NODE_NAME_CASE(XVMSKEQZ)
NODE_NAME_CASE(XVMSKNEZ)
}
#undef NODE_NAME_CASE
return nullptr;
Expand Down
12 changes: 11 additions & 1 deletion llvm/lib/Target/LoongArch/LoongArchISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,17 @@ enum NodeType : unsigned {
VBSRL,

// Scalar load broadcast to vector
VLDREPL
VLDREPL,

// Vector mask set by condition
VMSKLTZ,
VMSKGEZ,
VMSKEQZ,
VMSKNEZ,
XVMSKLTZ,
XVMSKGEZ,
XVMSKEQZ,
XVMSKNEZ,

// Intrinsic operations end =============================================
};
Expand Down
23 changes: 23 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

// Target nodes.
def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>;
def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>;
def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>;
def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>;
def loongarch_xvmsknez: SDNode<"LoongArchISD::XVMSKNEZ", SDT_LoongArchVMSKCOND>;

def lasxsplati8
: PatFrag<(ops node:$e0),
Expand Down Expand Up @@ -1086,6 +1090,16 @@ def PseudoXVINSGR2VR_H
: Pseudo<(outs LASX256:$dst), (ins LASX256:$xd, GPR:$rj, uimm4:$imm)>;
} // usesCustomInserter = 1, Constraints = "$xd = $dst"

let usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0 in {
def PseudoXVMSKLTZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
def PseudoXVMSKLTZ_H : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
def PseudoXVMSKLTZ_W : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
def PseudoXVMSKLTZ_D : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
def PseudoXVMSKGEZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
def PseudoXVMSKEQZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
def PseudoXVMSKNEZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
} // usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0

} // Predicates = [HasExtLASX]

multiclass PatXr<SDPatternOperator OpNode, string Inst> {
Expand Down Expand Up @@ -1856,6 +1870,15 @@ def : Pat<(vt (concat_vectors LSX128:$vd, LSX128:$vj)),
defm : PatXrXr<abds, "XVABSD">;
defm : PatXrXrU<abdu, "XVABSD">;

// Vector mask set by condition
def : Pat<(loongarch_xvmskltz (v32i8 LASX256:$vj)), (PseudoXVMSKLTZ_B LASX256:$vj)>;
def : Pat<(loongarch_xvmskltz (v16i16 LASX256:$vj)), (PseudoXVMSKLTZ_H LASX256:$vj)>;
def : Pat<(loongarch_xvmskltz (v8i32 LASX256:$vj)), (PseudoXVMSKLTZ_W LASX256:$vj)>;
def : Pat<(loongarch_xvmskltz (v4i64 LASX256:$vj)), (PseudoXVMSKLTZ_D LASX256:$vj)>;
def : Pat<(loongarch_xvmskgez (v32i8 LASX256:$vj)), (PseudoXVMSKGEZ_B LASX256:$vj)>;
def : Pat<(loongarch_xvmskeqz (v32i8 LASX256:$vj)), (PseudoXVMSKEQZ_B LASX256:$vj)>;
def : Pat<(loongarch_xvmsknez (v32i8 LASX256:$vj)), (PseudoXVMSKNEZ_B LASX256:$vj)>;

} // Predicates = [HasExtLASX]

/// Intrinsic pattern
Expand Down
25 changes: 25 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def SDT_LoongArchVreplgr2vr : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<0>, S
def SDT_LoongArchVFRECIPE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
def SDT_LoongArchVFRSQRTE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
def SDT_LoongArchVLDREPL : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisPtrTy<1>]>;
def SDT_LoongArchVMSKCOND : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<1>]>;

// Target nodes.
def loongarch_vreplve : SDNode<"LoongArchISD::VREPLVE", SDT_LoongArchVreplve>;
Expand Down Expand Up @@ -74,6 +75,11 @@ def loongarch_vldrepl
: SDNode<"LoongArchISD::VLDREPL",
SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;

def loongarch_vmskltz: SDNode<"LoongArchISD::VMSKLTZ", SDT_LoongArchVMSKCOND>;
def loongarch_vmskgez: SDNode<"LoongArchISD::VMSKGEZ", SDT_LoongArchVMSKCOND>;
def loongarch_vmskeqz: SDNode<"LoongArchISD::VMSKEQZ", SDT_LoongArchVMSKCOND>;
def loongarch_vmsknez: SDNode<"LoongArchISD::VMSKNEZ", SDT_LoongArchVMSKCOND>;

def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
Expand Down Expand Up @@ -1266,6 +1272,16 @@ let usesCustomInserter = 1 in
def PseudoCTPOP : Pseudo<(outs GPR:$rd), (ins GPR:$rj),
[(set GPR:$rd, (ctpop GPR:$rj))]>;

let usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0 in {
def PseudoVMSKLTZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
def PseudoVMSKLTZ_H : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
def PseudoVMSKLTZ_W : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
def PseudoVMSKLTZ_D : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
def PseudoVMSKGEZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
def PseudoVMSKEQZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
def PseudoVMSKNEZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
} // usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0

} // Predicates = [HasExtLSX]

multiclass PatVr<SDPatternOperator OpNode, string Inst> {
Expand Down Expand Up @@ -2050,6 +2066,15 @@ def : Pat<(f64 f64imm_vldi:$in),
defm : PatVrVr<abds, "VABSD">;
defm : PatVrVrU<abdu, "VABSD">;

// Vector mask set by condition
def : Pat<(loongarch_vmskltz (v16i8 LSX128:$vj)), (PseudoVMSKLTZ_B LSX128:$vj)>;
def : Pat<(loongarch_vmskltz (v8i16 LSX128:$vj)), (PseudoVMSKLTZ_H LSX128:$vj)>;
def : Pat<(loongarch_vmskltz (v4i32 LSX128:$vj)), (PseudoVMSKLTZ_W LSX128:$vj)>;
def : Pat<(loongarch_vmskltz (v2i64 LSX128:$vj)), (PseudoVMSKLTZ_D LSX128:$vj)>;
def : Pat<(loongarch_vmskgez (v16i8 LSX128:$vj)), (PseudoVMSKGEZ_B LSX128:$vj)>;
def : Pat<(loongarch_vmskeqz (v16i8 LSX128:$vj)), (PseudoVMSKEQZ_B LSX128:$vj)>;
def : Pat<(loongarch_vmsknez (v16i8 LSX128:$vj)), (PseudoVMSKNEZ_B LSX128:$vj)>;

} // Predicates = [HasExtLSX]

/// Intrinsic pattern
Expand Down
Loading