@@ -388,8 +388,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
388
388
389
389
// Set DAG combine for 'LSX' feature.
390
390
391
- if (Subtarget.hasExtLSX ())
391
+ if (Subtarget.hasExtLSX ()) {
392
392
setTargetDAGCombine (ISD::INTRINSIC_WO_CHAIN);
393
+ setTargetDAGCombine (ISD::BITCAST);
394
+ }
393
395
394
396
// Compute derived properties from the register classes.
395
397
computeRegisterProperties (Subtarget.getRegisterInfo ());
@@ -4286,6 +4288,94 @@ static SDValue performSRLCombine(SDNode *N, SelectionDAG &DAG,
4286
4288
return SDValue ();
4287
4289
}
4288
4290
4291
+ static SDValue performBITCASTCombine (SDNode *N, SelectionDAG &DAG,
4292
+ TargetLowering::DAGCombinerInfo &DCI,
4293
+ const LoongArchSubtarget &Subtarget) {
4294
+ SDLoc DL (N);
4295
+ EVT VT = N->getValueType (0 );
4296
+ SDValue Src = N->getOperand (0 );
4297
+ EVT SrcVT = Src.getValueType ();
4298
+
4299
+ if (!DCI.isBeforeLegalizeOps ())
4300
+ return SDValue ();
4301
+
4302
+ if (!SrcVT.isSimple () || SrcVT.getScalarType () != MVT::i1)
4303
+ return SDValue ();
4304
+
4305
+ if (Src.getOpcode () != ISD::SETCC || !Src.hasOneUse ())
4306
+ return SDValue ();
4307
+
4308
+ bool UseLASX;
4309
+ EVT CmpVT = Src.getOperand (0 ).getValueType ();
4310
+ EVT EltVT = CmpVT.getVectorElementType ();
4311
+ if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () <= 128 )
4312
+ UseLASX = false ;
4313
+ else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4314
+ CmpVT.getSizeInBits () <= 256 )
4315
+ UseLASX = true ;
4316
+ else
4317
+ return SDValue ();
4318
+
4319
+ unsigned ISD = ISD::DELETED_NODE;
4320
+ SDValue SrcN1 = Src.getOperand (1 );
4321
+ switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4322
+ default :
4323
+ return SDValue ();
4324
+ case ISD::SETEQ:
4325
+ if (EltVT == MVT::i8 ) {
4326
+ // x == 0 => not (vmsknez.b x)
4327
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()))
4328
+ ISD = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4329
+ // x == -1 => vmsknez.b x
4330
+ else if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()))
4331
+ ISD = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4332
+ }
4333
+ break ;
4334
+ case ISD::SETGT:
4335
+ // x > -1 => vmskgez.b x
4336
+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4337
+ ISD = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4338
+ break ;
4339
+ case ISD::SETGE:
4340
+ // x >= 0 => vmskgez.b x
4341
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4342
+ ISD = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4343
+ break ;
4344
+ case ISD::SETLT:
4345
+ // x < 0 => vmskltz.{b,h,w,d} x
4346
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4347
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4348
+ EltVT == MVT::i64 ))
4349
+ ISD = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4350
+ break ;
4351
+ case ISD::SETLE:
4352
+ // x <= -1 => vmskltz.{b,h,w,d} x
4353
+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4354
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4355
+ EltVT == MVT::i64 ))
4356
+ ISD = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4357
+ break ;
4358
+ case ISD::SETNE:
4359
+ if (EltVT == MVT::i8 ) {
4360
+ // x != 0 => vmsknez.b x
4361
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()))
4362
+ ISD = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4363
+ // x != -1 => not (vmsknez.b x)
4364
+ else if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()))
4365
+ ISD = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4366
+ }
4367
+ break ;
4368
+ }
4369
+
4370
+ if (ISD == ISD::DELETED_NODE)
4371
+ return SDValue ();
4372
+
4373
+ SDValue V = DAG.getNode (ISD, DL, MVT::i64 , Src.getOperand (0 ));
4374
+ EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
4375
+ V = DAG.getZExtOrTrunc (V, DL, T);
4376
+ return DAG.getBitcast (VT, V);
4377
+ }
4378
+
4289
4379
static SDValue performORCombine (SDNode *N, SelectionDAG &DAG,
4290
4380
TargetLowering::DAGCombinerInfo &DCI,
4291
4381
const LoongArchSubtarget &Subtarget) {
@@ -5303,6 +5393,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
5303
5393
return performSETCCCombine (N, DAG, DCI, Subtarget);
5304
5394
case ISD::SRL:
5305
5395
return performSRLCombine (N, DAG, DCI, Subtarget);
5396
+ case ISD::BITCAST:
5397
+ return performBITCASTCombine (N, DAG, DCI, Subtarget);
5306
5398
case LoongArchISD::BITREV_W:
5307
5399
return performBITREV_WCombine (N, DAG, DCI, Subtarget);
5308
5400
case ISD::INTRINSIC_WO_CHAIN:
@@ -5589,6 +5681,120 @@ static MachineBasicBlock *emitPseudoCTPOP(MachineInstr &MI,
5589
5681
return BB;
5590
5682
}
5591
5683
5684
+ static MachineBasicBlock *
5685
+ emitPseudoVMSKCOND (MachineInstr &MI, MachineBasicBlock *BB,
5686
+ const LoongArchSubtarget &Subtarget) {
5687
+ const TargetInstrInfo *TII = Subtarget.getInstrInfo ();
5688
+ const TargetRegisterClass *RC = &LoongArch::LSX128RegClass;
5689
+ const LoongArchRegisterInfo *TRI = Subtarget.getRegisterInfo ();
5690
+ MachineRegisterInfo &MRI = BB->getParent ()->getRegInfo ();
5691
+ Register Dst = MI.getOperand (0 ).getReg ();
5692
+ Register Src = MI.getOperand (1 ).getReg ();
5693
+ DebugLoc DL = MI.getDebugLoc ();
5694
+ unsigned EleBits = 8 ;
5695
+ unsigned NotOpc = 0 ;
5696
+ unsigned MskOpc;
5697
+
5698
+ switch (MI.getOpcode ()) {
5699
+ default :
5700
+ llvm_unreachable (" Unexpected opcode" );
5701
+ case LoongArch::PseudoVMSKLTZ_B:
5702
+ MskOpc = LoongArch::VMSKLTZ_B;
5703
+ break ;
5704
+ case LoongArch::PseudoVMSKLTZ_H:
5705
+ MskOpc = LoongArch::VMSKLTZ_H;
5706
+ EleBits = 16 ;
5707
+ break ;
5708
+ case LoongArch::PseudoVMSKLTZ_W:
5709
+ MskOpc = LoongArch::VMSKLTZ_W;
5710
+ EleBits = 32 ;
5711
+ break ;
5712
+ case LoongArch::PseudoVMSKLTZ_D:
5713
+ MskOpc = LoongArch::VMSKLTZ_D;
5714
+ EleBits = 64 ;
5715
+ break ;
5716
+ case LoongArch::PseudoVMSKGEZ_B:
5717
+ MskOpc = LoongArch::VMSKGEZ_B;
5718
+ break ;
5719
+ case LoongArch::PseudoVMSKEQZ_B:
5720
+ MskOpc = LoongArch::VMSKNZ_B;
5721
+ NotOpc = LoongArch::VNOR_V;
5722
+ break ;
5723
+ case LoongArch::PseudoVMSKNEZ_B:
5724
+ MskOpc = LoongArch::VMSKNZ_B;
5725
+ break ;
5726
+ case LoongArch::PseudoXVMSKLTZ_B:
5727
+ MskOpc = LoongArch::XVMSKLTZ_B;
5728
+ RC = &LoongArch::LASX256RegClass;
5729
+ break ;
5730
+ case LoongArch::PseudoXVMSKLTZ_H:
5731
+ MskOpc = LoongArch::XVMSKLTZ_H;
5732
+ RC = &LoongArch::LASX256RegClass;
5733
+ EleBits = 16 ;
5734
+ break ;
5735
+ case LoongArch::PseudoXVMSKLTZ_W:
5736
+ MskOpc = LoongArch::XVMSKLTZ_W;
5737
+ RC = &LoongArch::LASX256RegClass;
5738
+ EleBits = 32 ;
5739
+ break ;
5740
+ case LoongArch::PseudoXVMSKLTZ_D:
5741
+ MskOpc = LoongArch::XVMSKLTZ_D;
5742
+ RC = &LoongArch::LASX256RegClass;
5743
+ EleBits = 64 ;
5744
+ break ;
5745
+ case LoongArch::PseudoXVMSKGEZ_B:
5746
+ MskOpc = LoongArch::XVMSKGEZ_B;
5747
+ RC = &LoongArch::LASX256RegClass;
5748
+ break ;
5749
+ case LoongArch::PseudoXVMSKEQZ_B:
5750
+ MskOpc = LoongArch::XVMSKNZ_B;
5751
+ NotOpc = LoongArch::XVNOR_V;
5752
+ RC = &LoongArch::LASX256RegClass;
5753
+ break ;
5754
+ case LoongArch::PseudoXVMSKNEZ_B:
5755
+ MskOpc = LoongArch::XVMSKNZ_B;
5756
+ RC = &LoongArch::LASX256RegClass;
5757
+ break ;
5758
+ }
5759
+
5760
+ Register Msk = MRI.createVirtualRegister (RC);
5761
+ if (NotOpc) {
5762
+ Register Tmp = MRI.createVirtualRegister (RC);
5763
+ BuildMI (*BB, MI, DL, TII->get (MskOpc), Tmp).addReg (Src);
5764
+ BuildMI (*BB, MI, DL, TII->get (NotOpc), Msk)
5765
+ .addReg (Tmp, RegState::Kill)
5766
+ .addReg (Tmp, RegState::Kill);
5767
+ } else {
5768
+ BuildMI (*BB, MI, DL, TII->get (MskOpc), Msk).addReg (Src);
5769
+ }
5770
+
5771
+ if (TRI->getRegSizeInBits (*RC) > 128 ) {
5772
+ Register Lo = MRI.createVirtualRegister (&LoongArch::GPRRegClass);
5773
+ Register Hi = MRI.createVirtualRegister (&LoongArch::GPRRegClass);
5774
+ BuildMI (*BB, MI, DL, TII->get (LoongArch::XVPICKVE2GR_WU), Lo)
5775
+ .addReg (Msk, RegState::Kill)
5776
+ .addImm (0 );
5777
+ BuildMI (*BB, MI, DL, TII->get (LoongArch::XVPICKVE2GR_WU), Hi)
5778
+ .addReg (Msk, RegState::Kill)
5779
+ .addImm (4 );
5780
+ BuildMI (*BB, MI, DL,
5781
+ TII->get (Subtarget.is64Bit () ? LoongArch::BSTRINS_D
5782
+ : LoongArch::BSTRINS_W),
5783
+ Dst)
5784
+ .addReg (Lo, RegState::Kill)
5785
+ .addReg (Hi, RegState::Kill)
5786
+ .addImm (256 / EleBits - 1 )
5787
+ .addImm (128 / EleBits);
5788
+ } else {
5789
+ BuildMI (*BB, MI, DL, TII->get (LoongArch::VPICKVE2GR_HU), Dst)
5790
+ .addReg (Msk, RegState::Kill)
5791
+ .addImm (0 );
5792
+ }
5793
+
5794
+ MI.eraseFromParent ();
5795
+ return BB;
5796
+ }
5797
+
5592
5798
static bool isSelectPseudo (MachineInstr &MI) {
5593
5799
switch (MI.getOpcode ()) {
5594
5800
default :
@@ -5795,6 +6001,21 @@ MachineBasicBlock *LoongArchTargetLowering::EmitInstrWithCustomInserter(
5795
6001
return emitPseudoXVINSGR2VR (MI, BB, Subtarget);
5796
6002
case LoongArch::PseudoCTPOP:
5797
6003
return emitPseudoCTPOP (MI, BB, Subtarget);
6004
+ case LoongArch::PseudoVMSKLTZ_B:
6005
+ case LoongArch::PseudoVMSKLTZ_H:
6006
+ case LoongArch::PseudoVMSKLTZ_W:
6007
+ case LoongArch::PseudoVMSKLTZ_D:
6008
+ case LoongArch::PseudoVMSKGEZ_B:
6009
+ case LoongArch::PseudoVMSKEQZ_B:
6010
+ case LoongArch::PseudoVMSKNEZ_B:
6011
+ case LoongArch::PseudoXVMSKLTZ_B:
6012
+ case LoongArch::PseudoXVMSKLTZ_H:
6013
+ case LoongArch::PseudoXVMSKLTZ_W:
6014
+ case LoongArch::PseudoXVMSKLTZ_D:
6015
+ case LoongArch::PseudoXVMSKGEZ_B:
6016
+ case LoongArch::PseudoXVMSKEQZ_B:
6017
+ case LoongArch::PseudoXVMSKNEZ_B:
6018
+ return emitPseudoVMSKCOND (MI, BB, Subtarget);
5798
6019
case TargetOpcode::STATEPOINT:
5799
6020
// STATEPOINT is a pseudo instruction which has no implicit defs/uses
5800
6021
// while bl call instruction (where statepoint will be lowered at the
@@ -5916,6 +6137,14 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
5916
6137
NODE_NAME_CASE (VBSLL)
5917
6138
NODE_NAME_CASE (VBSRL)
5918
6139
NODE_NAME_CASE (VLDREPL)
6140
+ NODE_NAME_CASE (VMSKLTZ)
6141
+ NODE_NAME_CASE (VMSKGEZ)
6142
+ NODE_NAME_CASE (VMSKEQZ)
6143
+ NODE_NAME_CASE (VMSKNEZ)
6144
+ NODE_NAME_CASE (XVMSKLTZ)
6145
+ NODE_NAME_CASE (XVMSKGEZ)
6146
+ NODE_NAME_CASE (XVMSKEQZ)
6147
+ NODE_NAME_CASE (XVMSKNEZ)
5919
6148
}
5920
6149
#undef NODE_NAME_CASE
5921
6150
return nullptr ;
0 commit comments