
[NVPTX][NFC] Refactor and cleanup NVPTXISelLowering call lowering 2/n #137666


Merged

merged 5 commits into llvm:main on May 25, 2025

Conversation

AlexMaclean
Member

No description provided.

@llvmbot
Member

llvmbot commented Apr 28, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Patch is 60.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/137666.diff

3 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+418-547)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+5-5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+5-28)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c41741ed10232..b21635f7caf04 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -343,33 +343,35 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
 /// and promote them to a larger size if they're not.
 ///
 /// The promoted type is placed in \p PromoteVT if the function returns true.
-static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
+static std::optional<MVT> PromoteScalarIntegerPTX(const EVT &VT) {
   if (VT.isScalarInteger()) {
+    MVT PromotedVT;
     switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
     default:
       llvm_unreachable(
           "Promotion is not suitable for scalars of size larger than 64-bits");
     case 1:
-      *PromotedVT = MVT::i1;
+      PromotedVT = MVT::i1;
       break;
     case 2:
     case 4:
     case 8:
-      *PromotedVT = MVT::i8;
+      PromotedVT = MVT::i8;
       break;
     case 16:
-      *PromotedVT = MVT::i16;
+      PromotedVT = MVT::i16;
       break;
     case 32:
-      *PromotedVT = MVT::i32;
+      PromotedVT = MVT::i32;
       break;
     case 64:
-      *PromotedVT = MVT::i64;
+      PromotedVT = MVT::i64;
       break;
     }
-    return EVT(*PromotedVT) != VT;
+    if (VT != PromotedVT)
+      return PromotedVT;
   }
-  return false;
+  return std::nullopt;
 }
 
 // Check whether we can merge loads/stores of some of the pieces of a
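For context, a minimal sketch of how the refactored helper now reads at a call site (the caller context here is illustrative, not part of this patch):

  // Old API: bool return plus an MVT out-parameter.
  //   MVT PromotedVT;
  //   if (PromoteScalarIntegerPTX(EltVT, &PromotedVT))
  //     EltVT = EVT(PromotedVT);
  // New API: the promoted type, if any, is the return value itself.
  if (std::optional<MVT> PromotedVT = PromoteScalarIntegerPTX(EltVT))
    EltVT = EVT(*PromotedVT);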
@@ -426,16 +428,6 @@ static unsigned CanMergeParamLoadStoresStartingAt(
   return NumElts;
 }
 
-// Flags for tracking per-element vectorization state of loads/stores
-// of a flattened function parameter or return value.
-enum ParamVectorizationFlags {
-  PVF_INNER = 0x0, // Middle elements of a vector.
-  PVF_FIRST = 0x1, // First element of the vector.
-  PVF_LAST = 0x2,  // Last element of the vector.
-  // Scalar is effectively a 1-element vector.
-  PVF_SCALAR = PVF_FIRST | PVF_LAST
-};
-
 // Computes whether and how we can vectorize the loads/stores of a
 // flattened function parameter or return value.
 //
@@ -444,52 +436,39 @@ enum ParamVectorizationFlags {
 // of the same size as ValueVTs indicating how each piece should be
 // loaded/stored (i.e. as a scalar, or as part of a vector
 // load/store).
-static SmallVector<ParamVectorizationFlags, 16>
+static SmallVector<unsigned, 16>
 VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
                      const SmallVectorImpl<uint64_t> &Offsets,
                      Align ParamAlignment, bool IsVAArg = false) {
   // Set vector size to match ValueVTs and mark all elements as
   // scalars by default.
-  SmallVector<ParamVectorizationFlags, 16> VectorInfo;
-  VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
+  SmallVector<unsigned, 16> VectorInfo;
 
-  if (IsVAArg)
+  if (IsVAArg) {
+    VectorInfo.assign(ValueVTs.size(), 1);
     return VectorInfo;
+  }
 
-  // Check what we can vectorize using 128/64/32-bit accesses.
-  for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
-    // Skip elements we've already processed.
-    assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
-    for (unsigned AccessSize : {16, 8, 4, 2}) {
-      unsigned NumElts = CanMergeParamLoadStoresStartingAt(
+  const auto GetNumElts = [&](unsigned I) -> unsigned {
+    for (const unsigned AccessSize : {16, 8, 4, 2}) {
+      const unsigned NumElts = CanMergeParamLoadStoresStartingAt(
           I, AccessSize, ValueVTs, Offsets, ParamAlignment);
-      // Mark vectorized elements.
-      switch (NumElts) {
-      default:
-        llvm_unreachable("Unexpected return value");
-      case 1:
-        // Can't vectorize using this size, try next smaller size.
-        continue;
-      case 2:
-        assert(I + 1 < E && "Not enough elements.");
-        VectorInfo[I] = PVF_FIRST;
-        VectorInfo[I + 1] = PVF_LAST;
-        I += 1;
-        break;
-      case 4:
-        assert(I + 3 < E && "Not enough elements.");
-        VectorInfo[I] = PVF_FIRST;
-        VectorInfo[I + 1] = PVF_INNER;
-        VectorInfo[I + 2] = PVF_INNER;
-        VectorInfo[I + 3] = PVF_LAST;
-        I += 3;
-        break;
-      }
-      // Break out of the inner loop because we've already succeeded
-      // using largest possible AccessSize.
-      break;
+      assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
+             "Unexpected vectorization size");
+      if (NumElts != 1)
+        return NumElts;
     }
+    return 1;
+  };
+
+  // Check what we can vectorize using 128/64/32-bit accesses.
+  for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
+    const unsigned NumElts = GetNumElts(I);
+    VectorInfo.push_back(NumElts);
+    I += NumElts;
   }
+  assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
+         ValueVTs.size());
   return VectorInfo;
 }
 
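To illustrate the change in representation with hypothetical values: four contiguous i32 pieces at 16-byte alignment were previously flagged per element, and are now described by one chunk size per merged access.

  SmallVector<EVT, 16> VTs = {MVT::i32, MVT::i32, MVT::i32, MVT::i32};
  SmallVector<uint64_t, 16> Offsets = {0, 4, 8, 12};
  auto VI = VectorizePTXValueVTs(VTs, Offsets, Align(16));
  // Before: {PVF_FIRST, PVF_INNER, PVF_INNER, PVF_LAST}, one flag per element.
  // After:  {4}, one entry per merged chunk; the entries sum to VTs.size().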
@@ -1165,21 +1144,24 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
 
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *retTy, const ArgListTy &Args,
-    const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
-    std::optional<std::pair<unsigned, const APInt &>> VAInfo,
-    const CallBase &CB, unsigned UniqueCallSite) const {
+    const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign RetAlign,
+    std::optional<std::pair<unsigned, unsigned>> VAInfo, const CallBase &CB,
+    unsigned UniqueCallSite) const {
   auto PtrVT = getPointerTy(DL);
 
   std::string Prototype;
   raw_string_ostream O(Prototype);
   O << "prototype_" << UniqueCallSite << " : .callprototype ";
 
-  if (retTy->getTypeID() == Type::VoidTyID) {
+  if (retTy->isVoidTy()) {
     O << "()";
   } else {
     O << "(";
-    if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
-        !shouldPassAsArray(retTy)) {
+    if (shouldPassAsArray(retTy)) {
+      assert(RetAlign && "RetAlign must be set for non-void return types");
+      O << ".param .align " << RetAlign->value() << " .b8 _["
+        << DL.getTypeAllocSize(retTy) << "]";
+    } else if (retTy->isFloatingPointTy() || retTy->isIntegerTy()) {
       unsigned size = 0;
       if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
         size = ITy->getBitWidth();
@@ -1196,9 +1178,6 @@ std::string NVPTXTargetLowering::getPrototype(
       O << ".param .b" << size << " _";
     } else if (isa<PointerType>(retTy)) {
       O << ".param .b" << PtrVT.getSizeInBits() << " _";
-    } else if (shouldPassAsArray(retTy)) {
-      O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
-        << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
     } else {
       llvm_unreachable("Unknown return type");
     }
@@ -1208,57 +1187,52 @@ std::string NVPTXTargetLowering::getPrototype(
 
   bool first = true;
 
-  unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
-  for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
-    Type *Ty = Args[i].Ty;
+  const unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
+  auto AllOuts = ArrayRef(Outs);
+  for (const unsigned I : llvm::seq(NumArgs)) {
+    const auto ArgOuts =
+        AllOuts.take_while([I](auto O) { return O.OrigArgIndex == I; });
+    AllOuts = AllOuts.drop_front(ArgOuts.size());
+
+    Type *Ty = Args[I].Ty;
     if (!first) {
       O << ", ";
     }
     first = false;
 
-    if (!Outs[OIdx].Flags.isByVal()) {
+    if (ArgOuts[0].Flags.isByVal()) {
+      // Indirect calls need strict ABI alignment so we disable optimizations by
+      // not providing a function to optimize.
+      Type *ETy = Args[I].IndirectType;
+      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
+      Align ParamByValAlign =
+          getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
+
+      O << ".param .align " << ParamByValAlign.value() << " .b8 _["
+        << ArgOuts[0].Flags.getByValSize() << "]";
+    } else {
       if (shouldPassAsArray(Ty)) {
         Align ParamAlign =
-            getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
-        O << ".param .align " << ParamAlign.value() << " .b8 ";
-        O << "_";
-        O << "[" << DL.getTypeAllocSize(Ty) << "]";
-        // update the index for Outs
-        SmallVector<EVT, 16> vtparts;
-        ComputeValueVTs(*this, DL, Ty, vtparts);
-        if (unsigned len = vtparts.size())
-          OIdx += len - 1;
+            getArgumentAlignment(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
+        O << ".param .align " << ParamAlign.value() << " .b8 _["
+          << DL.getTypeAllocSize(Ty) << "]";
         continue;
       }
       // i8 types in IR will be i16 types in SDAG
-      assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
-              (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
+      assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
+              (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
              "type mismatch between callee prototype and arguments");
       // scalar type
       unsigned sz = 0;
-      if (isa<IntegerType>(Ty)) {
-        sz = cast<IntegerType>(Ty)->getBitWidth();
-        sz = promoteScalarArgumentSize(sz);
+      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+        sz = promoteScalarArgumentSize(ITy->getBitWidth());
       } else if (isa<PointerType>(Ty)) {
         sz = PtrVT.getSizeInBits();
       } else {
         sz = Ty->getPrimitiveSizeInBits();
       }
-      O << ".param .b" << sz << " ";
-      O << "_";
-      continue;
+      O << ".param .b" << sz << " _";
     }
-
-    // Indirect calls need strict ABI alignment so we disable optimizations by
-    // not providing a function to optimize.
-    Type *ETy = Args[i].IndirectType;
-    Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-    Align ParamByValAlign =
-        getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
-
-    O << ".param .align " << ParamByValAlign.value() << " .b8 ";
-    O << "_";
-    O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
   }
 
   if (VAInfo)
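For reference, the kind of prototype string this routine emits, with assumed example values (a callee returning i32 and taking an i32 plus a 16-byte aggregate; the exact formatting is an assumption, not output from this patch):

  prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .align 8 .b8 _[16]);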
@@ -1441,6 +1415,10 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
   return MachinePointerInfo();
 }
 
+static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
+  return Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+}
+
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
 
@@ -1451,8 +1429,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
   SelectionDAG &DAG = CLI.DAG;
   SDLoc dl = CLI.DL;
-  SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
-  SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
   SDValue Chain = CLI.Chain;
   SDValue Callee = CLI.Callee;
@@ -1462,6 +1438,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   const CallBase *CB = CLI.CB;
   const DataLayout &DL = DAG.getDataLayout();
 
+  const auto GetI32 = [&](const unsigned I) {
+    return DAG.getConstant(I, dl, MVT::i32);
+  };
+
   // Variadic arguments.
   //
   // Normally, for each argument, we declare a param scalar or a param
@@ -1479,7 +1459,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   // vararg byte array.
 
   SDValue VADeclareParam;                 // vararg byte array
-  unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
+  const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
   unsigned VAOffset = 0;                  // current offset in the param array
 
   const unsigned UniqueCallSite = GlobalUniqueCallSite++;
@@ -1487,7 +1467,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
   SDValue InGlue = Chain.getValue(1);
 
-  unsigned ParamCount = 0;
   // Args.size() and Outs.size() need not match.
   // Outs.size() will be larger
   //   * if there is an aggregate argument with multiple fields (each field
@@ -1497,77 +1476,78 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   //     individually present in Outs.
   // So a different index should be used for indexing into Outs/OutVals.
   // See similar issue in LowerFormalArguments.
-  unsigned OIdx = 0;
+  auto AllOuts = ArrayRef(CLI.Outs);
+  auto AllOutVals = ArrayRef(CLI.OutVals);
+  assert(AllOuts.size() == AllOutVals.size() &&
+         "Outs and OutVals must be the same size");
   // Declare the .params or .reg need to pass values
   // to the function
-  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
-    EVT VT = Outs[OIdx].VT;
-    Type *Ty = Args[i].Ty;
-    bool IsVAArg = (i >= CLI.NumFixedArgs);
-    bool IsByVal = Outs[OIdx].Flags.isByVal();
+  for (const auto [ArgI, Arg] : llvm::enumerate(Args)) {
+    const auto ArgOuts = AllOuts.take_while(
+        [ArgI = ArgI](auto O) { return O.OrigArgIndex == ArgI; });
+    const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
+    AllOuts = AllOuts.drop_front(ArgOuts.size());
+    AllOutVals = AllOutVals.drop_front(ArgOuts.size());
+
+    const bool IsVAArg = (ArgI >= FirstVAArg);
+    const bool IsByVal = Arg.IsByVal;
 
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
 
-    assert((!IsByVal || Args[i].IndirectType) &&
+    assert((!IsByVal || Arg.IndirectType) &&
            "byval arg must have indirect type");
-    Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
+    Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
     ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
+    assert(VTs.size() == Offsets.size() && "Size mismatch");
+    assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
 
     Align ArgAlign;
     if (IsByVal) {
       // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
       // so we don't need to worry whether it's naturally aligned or not.
       // See TargetLowering::LowerCallTo().
-      Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
       ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
                                             InitialAlign, DL);
       if (IsVAArg)
         VAOffset = alignTo(VAOffset, ArgAlign);
     } else {
-      ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL);
+      ArgAlign = getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
     }
 
-    unsigned TypeSize =
-        (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
-    SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+    const unsigned TypeSize = DL.getTypeAllocSize(ETy);
+    assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+           "type size mismatch");
 
-    bool NeedAlign; // Does argument declaration specify alignment?
-    const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
+    const bool PassAsArray = IsByVal || shouldPassAsArray(Arg.Ty);
     if (IsVAArg) {
-      if (ParamCount == FirstVAArg) {
-        SDValue DeclareParamOps[] = {
-            Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
-            DAG.getConstant(ParamCount, dl, MVT::i32),
-            DAG.getConstant(1, dl, MVT::i32), InGlue};
-        VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
-                                             DeclareParamVTs, DeclareParamOps);
+      if (ArgI == FirstVAArg) {
+        VADeclareParam = Chain =
+            DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
+                        {Chain, GetI32(STI.getMaxRequiredAlignment()),
+                         GetI32(ArgI), GetI32(1), InGlue});
       }
-      NeedAlign = PassAsArray;
     } else if (PassAsArray) {
       // declare .param .align <align> .b8 .param<n>[<size>];
-      SDValue DeclareParamOps[] = {
-          Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
-          DAG.getConstant(ParamCount, dl, MVT::i32),
-          DAG.getConstant(TypeSize, dl, MVT::i32), InGlue};
-      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
-                          DeclareParamOps);
-      NeedAlign = true;
+      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
+                          {Chain, GetI32(ArgAlign.value()), GetI32(ArgI),
+                           GetI32(TypeSize), InGlue});
     } else {
+      assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
       // declare .param .b<size> .param<n>;
-      if (VT.isInteger() || VT.isFloatingPoint()) {
-        // PTX ABI requires integral types to be at least 32 bits in
-        // size. FP16 is loaded/stored using i16, so it's handled
-        // here as well.
-        TypeSize = promoteScalarArgumentSize(TypeSize * 8) / 8;
-      }
-      SDValue DeclareScalarParamOps[] = {
-          Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
-          DAG.getConstant(TypeSize * 8, dl, MVT::i32),
-          DAG.getConstant(0, dl, MVT::i32), InGlue};
-      Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
-                          DeclareScalarParamOps);
-      NeedAlign = false;
+
+      // PTX ABI requires integral types to be at least 32 bits in
+      // size. FP16 is loaded/stored using i16, so it's handled
+      // here as well.
+      const unsigned PromotedSize =
+          (ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint())
+              ? promoteScalarArgumentSize(TypeSize * 8)
+              : TypeSize * 8;
+
+      Chain = DAG.getNode(
+          NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+          {Chain, GetI32(ArgI), GetI32(PromotedSize), GetI32(0), InGlue});
     }
     InGlue = Chain.getValue(1);
 
@@ -1575,196 +1555,169 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // than 32-bits are sign extended or zero extended, depending on
     // whether they are signed or unsigned types. This case applies
     // only to scalar parameters and not to aggregate values.
-    bool ExtendIntegerParam =
-        Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
+    const bool ExtendIntegerParam =
+        Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
 
-    auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
-    SmallVector<SDValue, 6> StoreOperands;
-    for (const unsigned J : llvm::seq(VTs.size())) {
-      EVT EltVT = VTs[J];
-      const int CurOffset = Offsets[J];
-      MaybeAlign PartAlign;
-      if (NeedAlign)
-        PartAlign = commonAlignment(ArgAlign, CurOffset);
+    const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
+                                    const Align PartAlign) {
+      SDValue StVal;
+      if (IsByVal) {
+        SDValue Ptr = ArgOutVals[0];
+        auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
+        SDValue SrcAddr =
+            DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
 
-      SDValue StVal = OutVals[OIdx];
+        StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
+      } else {
+        StVal = ArgOutVals[I];
 
-      MVT PromotedVT;
-      if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
-        EltVT = EVT(PromotedVT);
-      }
-      if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) {
-        llvm::ISD::NodeType Ext =
-            Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
-        StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
+        if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
+          StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, *PromotedVT,
+                              StVal);
+        }
       }
 
-      if (IsByVal) {
-        auto MPI = refinePtrAS(StVal, DAG, DL, *this);
-        const EVT PtrVT = StVal.getValueType();
-        SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
-                                      DAG.getConstant(CurOffset, dl, PtrVT));
-
-        StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
-      } else if (ExtendIntegerParam) {
+      if (ExtendIntegerParam) {
   ...
[truncated]

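The recurring idiom in this patch replaces manual OIdx bookkeeping with ArrayRef slicing: peel off the ISD::OutputArg entries belonging to each IR argument, then advance past them. In isolation, a sketch with the surrounding declarations assumed:

  auto AllOuts = ArrayRef(CLI.Outs);
  for (const unsigned I : llvm::seq(NumArgs)) {
    // Entries are grouped by OrigArgIndex, so take_while collects exactly
    // the flattened pieces of argument I.
    const auto ArgOuts =
        AllOuts.take_while([I](const auto &O) { return O.OrigArgIndex == I; });
    // Drop them so the next iteration starts at argument I + 1.
    AllOuts = AllOuts.drop_front(ArgOuts.size());
  }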

github-actions bot commented Apr 28, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/cleanup-428 branch from 70a1de7 to 43b630d on April 28, 2025 16:46
Member

@Artem-B Artem-B left a comment


Hold on a sec with merging. I've missed all the changes in NVPTXISelLowering.cpp. Thanks to GitHub for hiding the diff.

Member

@Artem-B Artem-B left a comment


Few nits. LGTM otherwise.
I'm actually surprised that there are no changes at all for the tests. I suspect that we may not have enough coverage.

@AlexMaclean
Member Author

I'm actually surprised that there are no changes at all for the tests. I suspect that we may not have enough coverage.

I agree we don't have the greatest coverage, but this change is intended to be non-functional. Is there something specific that looks like it would cause changes to code-gen?

@Artem-B
Member

Artem-B commented Apr 28, 2025

Is there something specific that looks like it would cause changes to code-gen?

Nothing specific, just that I'm somewhat uneasy about large changes in a common code path where we may not have particularly robust coverage. With all the refactoring and shuffling of the code around, I can't say that I'm confident that I didn't miss anything important. That's where the tests should provide additional "moral support", but we have what we have.

It's OK. If something is amiss, I'll probably see it shortly after merging, on our internal tests.

Member

@Artem-B Artem-B left a comment


LGTM.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/cleanup-428 branch from 59e0832 to 4b77338 on May 25, 2025 02:52
@AlexMaclean AlexMaclean merged commit 7204141 into llvm:main May 25, 2025
11 checks passed
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025