Skip to content

Commit

Permalink
ARROW-18112: [Go] Remaining Scalar Arithmetic (#14777)
Browse files Browse the repository at this point in the history
Authored-by: Matt Topol <zotthewizard@gmail.com>
Signed-off-by: Matt Topol <zotthewizard@gmail.com>
  • Loading branch information
zeroshade committed Dec 1, 2022
1 parent 1d9f778 commit 1c8853b
Show file tree
Hide file tree
Showing 20 changed files with 3,181 additions and 160 deletions.
1 change: 1 addition & 0 deletions dev/release/rat_exclude_files.txt
Expand Up @@ -141,6 +141,7 @@ go/arrow/compute/go.sum
go/arrow/compute/datumkind_string.go
go/arrow/compute/funckind_string.go
go/arrow/compute/internal/kernels/compareoperator_string.go
go/arrow/compute/internal/kernels/roundmode_string.go
go/arrow/compute/internal/kernels/_lib/vendored/*
go/*.tmpldata
go/*.s
Expand Down
286 changes: 286 additions & 0 deletions go/arrow/compute/arithmetic.go
Expand Up @@ -25,6 +25,44 @@ import (
"github.com/apache/arrow/go/v11/arrow"
"github.com/apache/arrow/go/v11/arrow/compute/internal/exec"
"github.com/apache/arrow/go/v11/arrow/compute/internal/kernels"
"github.com/apache/arrow/go/v11/arrow/decimal128"
"github.com/apache/arrow/go/v11/arrow/decimal256"
"github.com/apache/arrow/go/v11/arrow/scalar"
)

type (
RoundOptions = kernels.RoundOptions
RoundMode = kernels.RoundMode
RoundToMultipleOptions = kernels.RoundToMultipleOptions
)

const (
// Round to nearest integer less than or equal in magnitude (aka "floor")
RoundDown = kernels.RoundDown
// Round to nearest integer greater than or equal in magnitude (aka "ceil")
RoundUp = kernels.RoundUp
// Get integral part without fractional digits (aka "trunc")
RoundTowardsZero = kernels.TowardsZero
// Round negative values with DOWN and positive values with UP
RoundTowardsInfinity = kernels.AwayFromZero
// Round ties with DOWN (aka "round half towards negative infinity")
RoundHalfDown = kernels.HalfDown
// Round ties with UP (aka "round half towards positive infinity")
RoundHalfUp = kernels.HalfUp
// Round ties with TowardsZero (aka "round half away from infinity")
RoundHalfTowardsZero = kernels.HalfTowardsZero
// Round ties with AwayFromZero (aka "round half towards infinity")
RoundHalfTowardsInfinity = kernels.HalfAwayFromZero
// Round ties to nearest even integer
RoundHalfToEven = kernels.HalfToEven
// Round ties to nearest odd integer
RoundHalfToOdd = kernels.HalfToOdd
)

var (
DefaultRoundOptions = RoundOptions{NDigits: 0, Mode: RoundHalfToEven}
DefaultRoundToMultipleOptions = RoundToMultipleOptions{
Multiple: scalar.NewFloat64Scalar(1), Mode: RoundHalfToEven}
)

type arithmeticFunction struct {
Expand Down Expand Up @@ -121,6 +159,7 @@ func (fn *arithmeticFloatingPointFunc) DispatchBest(vals ...arrow.DataType) (exe
return fn.DispatchExact(vals...)
}

// function that promotes only decimal arguments to float64
type arithmeticDecimalToFloatingPointFunc struct {
arithmeticFunction
}
Expand Down Expand Up @@ -156,6 +195,46 @@ func (fn *arithmeticDecimalToFloatingPointFunc) DispatchBest(vals ...arrow.DataT
return fn.DispatchExact(vals...)
}

// function that promotes only integer arguments to float64
type arithmeticIntegerToFloatingPointFunc struct {
arithmeticFunction
}

func (fn *arithmeticIntegerToFloatingPointFunc) Execute(ctx context.Context, opts FunctionOptions, args ...Datum) (Datum, error) {
return execInternal(ctx, fn, opts, -1, args...)
}

func (fn *arithmeticIntegerToFloatingPointFunc) DispatchBest(vals ...arrow.DataType) (exec.Kernel, error) {
if err := fn.checkArity(len(vals)); err != nil {
return nil, err
}

if err := fn.checkDecimals(vals...); err != nil {
return nil, err
}

if kn, err := fn.DispatchExact(vals...); err == nil {
return kn, nil
}

ensureDictionaryDecoded(vals...)
if len(vals) == 2 {
replaceNullWithOtherType(vals...)
}

for i, t := range vals {
if arrow.IsInteger(t.ID()) {
vals[i] = arrow.PrimitiveTypes.Float64
}
}

if dt := commonNumeric(vals...); dt != nil {
replaceTypes(dt, vals...)
}

return fn.DispatchExact(vals...)
}

var (
addDoc FunctionDoc
)
Expand Down Expand Up @@ -382,6 +461,25 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
}{
{"sqrt_unchecked", kernels.OpSqrt, decPromoteNone},
{"sqrt", kernels.OpSqrtChecked, decPromoteNone},
{"sin_unchecked", kernels.OpSin, decPromoteNone},
{"sin", kernels.OpSinChecked, decPromoteNone},
{"cos_unchecked", kernels.OpCos, decPromoteNone},
{"cos", kernels.OpCosChecked, decPromoteNone},
{"tan_unchecked", kernels.OpTan, decPromoteNone},
{"tan", kernels.OpTanChecked, decPromoteNone},
{"asin_unchecked", kernels.OpAsin, decPromoteNone},
{"asin", kernels.OpAsinChecked, decPromoteNone},
{"acos_unchecked", kernels.OpAcos, decPromoteNone},
{"acos", kernels.OpAcosChecked, decPromoteNone},
{"atan", kernels.OpAtan, decPromoteNone},
{"ln_unchecked", kernels.OpLn, decPromoteNone},
{"ln", kernels.OpLnChecked, decPromoteNone},
{"log10_unchecked", kernels.OpLog10, decPromoteNone},
{"log10", kernels.OpLog10Checked, decPromoteNone},
{"log2_unchecked", kernels.OpLog2, decPromoteNone},
{"log2", kernels.OpLog2Checked, decPromoteNone},
{"log1p_unchecked", kernels.OpLog1p, decPromoteNone},
{"log1p", kernels.OpLog1pChecked, decPromoteNone},
}

for _, o := range ops {
Expand All @@ -396,6 +494,28 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
reg.AddFunction(fn, false)
}

ops = []struct {
funcName string
op kernels.ArithmeticOp
decPromote decimalPromotion
}{
{"atan2", kernels.OpAtan2, decPromoteNone},
{"logb_unchecked", kernels.OpLogb, decPromoteNone},
{"logb", kernels.OpLogbChecked, decPromoteNone},
}

for _, o := range ops {
fn := &arithmeticFloatingPointFunc{arithmeticFunction{*NewScalarFunction(o.funcName, Binary(), addDoc), decPromoteNone}}
kns := kernels.GetArithmeticFloatingPointKernels(o.op)
for _, k := range kns {
if err := fn.AddKernel(k); err != nil {
panic(err)
}
}

reg.AddFunction(fn, false)
}

fn = &arithmeticFunction{*NewScalarFunction("sign", Unary(), addDoc), decPromoteNone}
kns = kernels.GetArithmeticUnaryFixedIntOutKernels(arrow.PrimitiveTypes.Int8, kernels.OpSign)
for _, k := range kns {
Expand Down Expand Up @@ -446,6 +566,15 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
reg.AddFunction(fn, false)
}

fn = &arithmeticFunction{*NewScalarFunction("bit_wise_not", Unary(), EmptyFuncDoc), decPromoteNone}
for _, k := range kernels.GetBitwiseUnaryKernels() {
if err := fn.AddKernel(k); err != nil {
panic(err)
}
}

reg.AddFunction(fn, false)

shiftOps := []struct {
funcName string
dir kernels.ShiftDir
Expand All @@ -467,6 +596,67 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
}
reg.AddFunction(fn, false)
}

floorFn := &arithmeticIntegerToFloatingPointFunc{arithmeticFunction{*NewScalarFunction("floor", Unary(), EmptyFuncDoc), decPromoteNone}}
kns = kernels.GetSimpleRoundKernels(kernels.RoundDown)
for _, k := range kns {
if err := floorFn.AddKernel(k); err != nil {
panic(err)
}
}
floorFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL128)},
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal128.Num](kernels.RoundDown), nil)
floorFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL256)},
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal256.Num](kernels.RoundDown), nil)
reg.AddFunction(floorFn, false)

ceilFn := &arithmeticIntegerToFloatingPointFunc{arithmeticFunction{*NewScalarFunction("ceil", Unary(), EmptyFuncDoc), decPromoteNone}}
kns = kernels.GetSimpleRoundKernels(kernels.RoundUp)
for _, k := range kns {
if err := ceilFn.AddKernel(k); err != nil {
panic(err)
}
}
ceilFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL128)},
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal128.Num](kernels.RoundUp), nil)
ceilFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL256)},
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal256.Num](kernels.RoundUp), nil)
reg.AddFunction(ceilFn, false)

truncFn := &arithmeticIntegerToFloatingPointFunc{arithmeticFunction{*NewScalarFunction("trunc", Unary(), EmptyFuncDoc), decPromoteNone}}
kns = kernels.GetSimpleRoundKernels(kernels.TowardsZero)
for _, k := range kns {
if err := truncFn.AddKernel(k); err != nil {
panic(err)
}
}
truncFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL128)},
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal128.Num](kernels.TowardsZero), nil)
truncFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL256)},
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal256.Num](kernels.TowardsZero), nil)
reg.AddFunction(truncFn, false)

roundFn := &arithmeticIntegerToFloatingPointFunc{arithmeticFunction{*NewScalarFunction("round", Unary(), EmptyFuncDoc), decPromoteNone}}
kns = kernels.GetRoundUnaryKernels(kernels.InitRoundState, kernels.UnaryRoundExec)
for _, k := range kns {
if err := roundFn.AddKernel(k); err != nil {
panic(err)
}
}

roundFn.defaultOpts = DefaultRoundOptions
reg.AddFunction(roundFn, false)

roundToMultipleFn := &arithmeticIntegerToFloatingPointFunc{arithmeticFunction{*NewScalarFunction("round_to_multiple", Unary(), EmptyFuncDoc), decPromoteNone}}
kns = kernels.GetRoundUnaryKernels(kernels.InitRoundToMultipleState, kernels.UnaryRoundToMultipleExec)
for _, k := range kns {
if err := roundToMultipleFn.AddKernel(k); err != nil {
panic(err)
}
}

roundToMultipleFn.defaultOpts = DefaultRoundToMultipleOptions
reg.AddFunction(roundToMultipleFn, false)
}

func impl(ctx context.Context, fn string, opts ArithmeticOptions, left, right Datum) (Datum, error) {
Expand Down Expand Up @@ -596,3 +786,99 @@ func ShiftRight(ctx context.Context, opts ArithmeticOptions, lhs, rhs Datum) (Da
}
return CallFunction(ctx, fn, nil, lhs, rhs)
}

func Sin(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
fn := "sin"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, arg)
}

func Cos(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
fn := "cos"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, arg)
}

func Tan(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
fn := "tan"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, arg)
}

func Asin(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
fn := "asin"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, arg)
}

func Acos(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
fn := "acos"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, arg)
}

func Atan(ctx context.Context, arg Datum) (Datum, error) {
return CallFunction(ctx, "atan", nil, arg)
}

func Atan2(ctx context.Context, x, y Datum) (Datum, error) {
return CallFunction(ctx, "atan2", nil, x, y)
}

func Ln(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
fn := "ln"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, arg)
}

func Log10(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
fn := "log10"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, arg)
}

func Log2(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
fn := "log2"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, arg)
}

func Log1p(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
fn := "log1p"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, arg)
}

func Logb(ctx context.Context, opts ArithmeticOptions, x, base Datum) (Datum, error) {
fn := "logb"
if opts.NoCheckOverflow {
fn += "_unchecked"
}
return CallFunction(ctx, fn, nil, x, base)
}

func Round(ctx context.Context, opts RoundOptions, arg Datum) (Datum, error) {
return CallFunction(ctx, "round", &opts, arg)
}

func RoundToMultiple(ctx context.Context, opts RoundToMultipleOptions, arg Datum) (Datum, error) {
return CallFunction(ctx, "round_to_multiple", &opts, arg)
}

0 comments on commit 1c8853b

Please sign in to comment.