Skip to content

Commit

Permalink
Fixed Decimal 128 Multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
razajafri committed Dec 7, 2023
1 parent 6b65d39 commit 175dc8d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import ai.rapids.cudf.ast.BinaryOperator
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.{GpuTypeShims, ShimExpression, SparkShimImpl}
import com.nvidia.spark.rapids.shims.{DecimalUtilShims, GpuTypeShims, ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, ExpectsInputTypes, Expression, NullIntolerant}
Expand Down Expand Up @@ -452,7 +452,7 @@ trait GpuDecimalMultiplyBase extends GpuExpression {
rhs.getBase.castTo(DType.create(DType.DTypeEnum.DECIMAL128, rhs.getBase.getType.getScale))
}
withResource(castRhs) { castRhs =>
com.nvidia.spark.rapids.jni.DecimalUtils.multiply128(castLhs, castRhs, -dataType.scale)
DecimalUtilShims.multiply128(castLhs, castRhs, -dataType.scale)
}
}
val retCol = withResource(retTab) { retTab =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "311"}
{"spark": "312"}
{"spark": "313"}
{"spark": "320"}
{"spark": "321"}
{"spark": "321cdh"}
{"spark": "321db"}
{"spark": "322"}
{"spark": "323"}
{"spark": "324"}
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "350"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import ai.rapids.cudf.{ColumnView, Table}
import com.nvidia.spark.rapids.jni.{DecimalUtils => du}

object DecimalUtilShims {
def multiply128(castLhs: ColumnView, castRhs: ColumnView, scale: Int): Table = {
du.multiply128(castLhs, castRhs, scale)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "351"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import ai.rapids.cudf.{ColumnView, Table}
import com.nvidia.spark.rapids.jni.{DecimalUtils => du}

object DecimalUtilShims {
def multiply128(castLhs: ColumnView, castRhs: ColumnView, scale: Int): Table = {
du.mul128(castLhs, castRhs, scale)
}
}

0 comments on commit 175dc8d

Please sign in to comment.