From ab5c1d2fea2af5317d403087f76bb5b1e8262544 Mon Sep 17 00:00:00 2001
From: Matt Topol
Date: Wed, 31 Aug 2022 13:29:47 -0400
Subject: [PATCH] ARROW-17586: [Go] String To Numeric cast functions

---
 go/arrow/compute/cast.go                      |   6 ++
 go/arrow/compute/cast_test.go                 | 101 +++++++++++++++++-
 go/arrow/compute/datum.go                     |   4 +
 go/arrow/compute/internal/kernels/helpers.go  |   4 +-
 .../compute/internal/kernels/numeric_cast.go  | 102 +++++++++++++++++-
 go/arrow/scalar/scalar.go                     |   2 +
 6 files changed, 211 insertions(+), 8 deletions(-)

diff --git a/go/arrow/compute/cast.go b/go/arrow/compute/cast.go
index 6bdb5d767cba2..09a3c75c0a185 100644
--- a/go/arrow/compute/cast.go
+++ b/go/arrow/compute/cast.go
@@ -252,6 +252,12 @@ func CastArray(ctx context.Context, val arrow.Array, opts *CastOptions) (arrow.A
 	return out.(*ArrayDatum).MakeArray(), nil
 }
 
+// CastToType is a convenience function equivalent to calling
+// CastArray(ctx, val, compute.SafeCastOptions(toType))
+func CastToType(ctx context.Context, val arrow.Array, toType arrow.DataType) (arrow.Array, error) {
+	return CastArray(ctx, val, SafeCastOptions(toType))
+}
+
 // CanCast returns true if there is an implementation for casting an array
 // or scalar value from the specified DataType to the other data type.
 func CanCast(from, to arrow.DataType) bool {
diff --git a/go/arrow/compute/cast_test.go b/go/arrow/compute/cast_test.go
index 7c5c422713d8f..0e0191d13e654 100644
--- a/go/arrow/compute/cast_test.go
+++ b/go/arrow/compute/cast_test.go
@@ -198,16 +198,19 @@ func checkCastZeroCopy(t *testing.T, input arrow.Array, toType arrow.DataType, o
 }
 
 var (
-	integerTypes = []arrow.DataType{
-		arrow.PrimitiveTypes.Uint8,
+	signedIntTypes = []arrow.DataType{
 		arrow.PrimitiveTypes.Int8,
-		arrow.PrimitiveTypes.Uint16,
 		arrow.PrimitiveTypes.Int16,
-		arrow.PrimitiveTypes.Uint32,
 		arrow.PrimitiveTypes.Int32,
-		arrow.PrimitiveTypes.Uint64,
 		arrow.PrimitiveTypes.Int64,
 	}
+	unsignedIntTypes = []arrow.DataType{
+		arrow.PrimitiveTypes.Uint8,
+		arrow.PrimitiveTypes.Uint16,
+		arrow.PrimitiveTypes.Uint32,
+		arrow.PrimitiveTypes.Uint64,
+	}
+	integerTypes = append(signedIntTypes, unsignedIntTypes...)
 	numericTypes = append(integerTypes,
 		arrow.PrimitiveTypes.Float32,
 		arrow.PrimitiveTypes.Float64)
@@ -1211,6 +1214,94 @@ func (c *CastSuite) TestDecimalToFloating() {
 	}
 }
 
+func (c *CastSuite) TestStringToInt() {
+	for _, stype := range []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString} {
+		for _, dt := range signedIntTypes {
+			c.checkCast(stype, dt,
+				`["0", null, "127", "-1", "0", "0x0", "0x7F", "0010"]`,
+				`[0, null, 127, -1, 0, 0, 127, 10]`)
+		}
+
+		c.checkCast(stype, arrow.PrimitiveTypes.Int32,
+			`["2147483647", null, "-2147483648", "0", "0X0", "0x7FFFFFFF", "-0X1", "-0x10000000"]`,
+			`[2147483647, null, -2147483648, 0, 0, 2147483647, -1, -268435456]`)
+
+		c.checkCast(stype, arrow.PrimitiveTypes.Int64,
+			`["9223372036854775807", null, "-9223372036854775808", "0", "0x0", "0x7FFFFFFFFFFFFFFf", "-0x0FFFFFFFFFFFFFFF"]`,
+			`[9223372036854775807, null, -9223372036854775808, 0, 0, 9223372036854775807, -1152921504606846975]`)
+
+		for _, dt := range unsignedIntTypes {
+			c.checkCast(stype, dt, `["0", null, "127", "255", "0", "0x0", "0xff", "0X7f", "009"]`,
+				`[0, null, 127, 255, 0, 0, 255, 127, 9]`)
+		}
+
+		c.checkCast(stype, arrow.PrimitiveTypes.Uint32,
+			`["2147483647", null, "4294967295", "0", "0x0", "0x7FFFFFFf", "0xFFFFFFFF"]`,
+			`[2147483647, null, 4294967295, 0, 0, 2147483647, 4294967295]`)
+
+		c.checkCast(stype, arrow.PrimitiveTypes.Uint64,
+			`["9223372036854775807", null, "18446744073709551615", "0", "0x0", "0x7FFFFFFFFFFFFFFf", "0xfFFFFFFFFFFFFFFf"]`,
+			`[9223372036854775807, null, 18446744073709551615, 0, 0, 9223372036854775807, 18446744073709551615]`)
+
+		for _, notInt8 := range []string{"z", "12 z", "128", "-129", "0.5", "0x", "0xfff", "-0xf0", "0b1", "1_0"} {
+			c.checkCastFails(stype, `["`+notInt8+`"]`, compute.SafeCastOptions(arrow.PrimitiveTypes.Int8))
+		}
+
+		for _, notUint8 := range []string{"256", "-1", "0.5", "0x", "0x3wa", "0x123"} {
+			c.checkCastFails(stype, `["`+notUint8+`"]`, compute.SafeCastOptions(arrow.PrimitiveTypes.Uint8))
+		}
+	}
+}
+
+func (c *CastSuite) TestStringToFloating() {
+	for _, stype := range []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString} {
+		for _, dt := range []arrow.DataType{arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float64} {
+			c.checkCast(stype, dt, `["0.1", null, "127.3", "1e3", "200.4", "0.5"]`,
+				`[0.1, null, 127.3, 1000, 200.4, 0.5]`)
+
+			for _, notFloat := range []string{"z"} {
+				c.checkCastFails(stype, `["`+notFloat+`"]`, compute.SafeCastOptions(dt))
+			}
+		}
+	}
+}
+
+func (c *CastSuite) TestUnsupportedInputType() {
+	// casting to a supported target type, but with an unsupported
+	// input for that target type.
+	arr, _, _ := array.FromJSON(c.mem, arrow.PrimitiveTypes.Int32, strings.NewReader(`[1, 2, 3]`))
+	defer arr.Release()
+
+	toType := arrow.ListOf(arrow.BinaryTypes.String)
+	_, err := compute.CastToType(context.Background(), arr, toType)
+	c.ErrorIs(err, arrow.ErrNotImplemented)
+	c.ErrorContains(err, "unsupported cast to list from int32")
+
+	// test calling through the generic kernel API
+	datum := compute.NewDatum(arr)
+	defer datum.Release()
+	_, err = compute.CallFunction(context.Background(), "cast", compute.SafeCastOptions(toType), datum)
+	c.ErrorIs(err, arrow.ErrNotImplemented)
+	c.ErrorContains(err, "unsupported cast to list from int32")
+}
+
+func (c *CastSuite) TestUnsupportedTargetType() {
+	arr, _, _ := array.FromJSON(c.mem, arrow.PrimitiveTypes.Int32, strings.NewReader(`[1, 2, 3]`))
+	defer arr.Release()
+
+	toType := arrow.DenseUnionOf([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32}}, []arrow.UnionTypeCode{0})
+	_, err := compute.CastToType(context.Background(), arr, toType)
+	c.ErrorIs(err, arrow.ErrNotImplemented)
+	c.ErrorContains(err, "unsupported cast to dense_union from int32")
+
+	// test calling through the generic kernel API
+	datum := compute.NewDatum(arr)
+	defer datum.Release()
+	_, err = compute.CallFunction(context.Background(), "cast", compute.SafeCastOptions(toType), datum)
+	c.ErrorIs(err, arrow.ErrNotImplemented)
+	c.ErrorContains(err, "unsupported cast to dense_union from int32")
+}
+
 func (c *CastSuite) checkCastZeroCopy(dt arrow.DataType, json string) {
 	arr, _, _ := array.FromJSON(c.mem, dt, strings.NewReader(json))
 	defer arr.Release()
diff --git a/go/arrow/compute/datum.go b/go/arrow/compute/datum.go
index 4243344637753..b5b88613b4447 100644
--- a/go/arrow/compute/datum.go
+++ b/go/arrow/compute/datum.go
@@ -122,6 +122,10 @@ type releasable interface {
 }
 
 func (d *ScalarDatum) Release() {
+	if !d.Value.IsValid() {
+		return
+	}
+
 	if v, ok := d.Value.(releasable); ok {
 		v.Release()
 	}
diff --git a/go/arrow/compute/internal/kernels/helpers.go b/go/arrow/compute/internal/kernels/helpers.go
index 5940a071a2f9c..087ccbf88d55c 100644
--- a/go/arrow/compute/internal/kernels/helpers.go
+++ b/go/arrow/compute/internal/kernels/helpers.go
@@ -123,7 +123,7 @@ func ScalarUnaryNotNullBinaryArgBoolOut[OffsetT int32 | int64](defVal bool, op f
 // It implements the handling to iterate the offsets and values calling
 // the provided function on each byte slice. The zero value of the OutT
 // will be used as the output for elements of the input that are null.
-func ScalarUnaryNotNullBinaryArg[OutT exec.FixedWidthTypes, OffsetT int32 | int64](op func(*exec.KernelCtx, []byte) (OutT, error)) exec.ArrayKernelExec {
+func ScalarUnaryNotNullBinaryArg[OutT exec.FixedWidthTypes, OffsetT int32 | int64](op func(*exec.KernelCtx, []byte, *error) OutT) exec.ArrayKernelExec {
 	return func(ctx *exec.KernelCtx, in *exec.ExecSpan, out *exec.ExecResult) error {
 		var (
 			arg0 = &in.Values[0].Array
@@ -139,7 +139,7 @@ func ScalarUnaryNotNullBinaryArg[OutT exec.FixedWidthTypes, OffsetT int32 | int6
 		bitutils.VisitBitBlocks(bitmap, arg0.Offset, arg0.Len,
 			func(pos int64) {
 				v := arg0Data[arg0Offsets[pos]:arg0Offsets[pos+1]]
-				outData[outPos], err = op(ctx, v)
+				outData[outPos] = op(ctx, v, &err)
 				outPos++
 			}, func() {
 				outData[outPos] = def
diff --git a/go/arrow/compute/internal/kernels/numeric_cast.go b/go/arrow/compute/internal/kernels/numeric_cast.go
index 0856e3d0651c9..40ba59285dd9d 100644
--- a/go/arrow/compute/internal/kernels/numeric_cast.go
+++ b/go/arrow/compute/internal/kernels/numeric_cast.go
@@ -18,6 +18,8 @@ package kernels
 
 import (
 	"fmt"
+	"strconv"
+	"unsafe"
 
 	"github.com/apache/arrow/go/v10/arrow"
 	"github.com/apache/arrow/go/v10/arrow/bitutil"
@@ -661,6 +663,95 @@ func checkIntToFloatTrunc(in *exec.ArraySpan, outType arrow.Type) error {
 	return nil
 }
 
+// parseStringToNumberImpl returns a kernel that calls parseFn on each non-null
+// element of a binary or string array and stores the result as a T. A parse
+// failure is surfaced through the kernel error, wrapped in arrow.ErrInvalid.
+func parseStringToNumberImpl[T exec.IntTypes | exec.UintTypes | exec.FloatTypes, OffsetT int32 | int64](parseFn func(string) (T, error)) exec.ArrayKernelExec {
+	return ScalarUnaryNotNullBinaryArg[T, OffsetT](func(_ *exec.KernelCtx, in []byte, err *error) T {
+		// View the bytes as a string without copying them. The string aliases
+		// the input buffer, so parseFn must not retain it past the call.
+		st := *(*string)(unsafe.Pointer(&in))
+		v, e := parseFn(st)
+		if e != nil {
+			*err = fmt.Errorf("%w: %s", arrow.ErrInvalid, e)
+		}
+		return v
+	})
+}
+
+// intParseBase returns the base to pass to strconv.ParseInt/ParseUint for s:
+// 0 when s has a "0x" or "0X" prefix (after an optional sign), so strconv
+// decodes it as hexadecimal, and 10 for everything else. Using base 0 for
+// every input would treat a leading "0" as an octal prefix, turning "010"
+// into 8 and rejecting "09", and would also accept "0b"/"0o" prefixes and
+// "_" digit separators in decimal text.
+func intParseBase(s string) int {
+	if len(s) > 0 && (s[0] == '+' || s[0] == '-') {
+		s = s[1:]
+	}
+	if len(s) > 2 && s[0] == '0' && (s[1] == 'x' || s[1] == 'X') {
+		return 0
+	}
+	return 10
+}
+
+// getParseStringExec returns the kernel that parses string elements into the
+// numeric type identified by out. Integers may be written in decimal or, with
+// a "0x"/"0X" prefix, in hexadecimal; values that do not fit the target width
+// are errors. It panics if out is not one of the types handled in the switch.
+func getParseStringExec[OffsetT int32 | int64](out arrow.Type) exec.ArrayKernelExec {
+	switch out {
+	case arrow.INT8:
+		return parseStringToNumberImpl[int8, OffsetT](func(s string) (int8, error) {
+			v, err := strconv.ParseInt(s, intParseBase(s), 8)
+			return int8(v), err
+		})
+	case arrow.UINT8:
+		return parseStringToNumberImpl[uint8, OffsetT](func(s string) (uint8, error) {
+			v, err := strconv.ParseUint(s, intParseBase(s), 8)
+			return uint8(v), err
+		})
+	case arrow.INT16:
+		return parseStringToNumberImpl[int16, OffsetT](func(s string) (int16, error) {
+			v, err := strconv.ParseInt(s, intParseBase(s), 16)
+			return int16(v), err
+		})
+	case arrow.UINT16:
+		return parseStringToNumberImpl[uint16, OffsetT](func(s string) (uint16, error) {
+			v, err := strconv.ParseUint(s, intParseBase(s), 16)
+			return uint16(v), err
+		})
+	case arrow.INT32:
+		return parseStringToNumberImpl[int32, OffsetT](func(s string) (int32, error) {
+			v, err := strconv.ParseInt(s, intParseBase(s), 32)
+			return int32(v), err
+		})
+	case arrow.UINT32:
+		return parseStringToNumberImpl[uint32, OffsetT](func(s string) (uint32, error) {
+			v, err := strconv.ParseUint(s, intParseBase(s), 32)
+			return uint32(v), err
+		})
+	case arrow.INT64:
+		return parseStringToNumberImpl[int64, OffsetT](func(s string) (int64, error) {
+			return strconv.ParseInt(s, intParseBase(s), 64)
+		})
+	case arrow.UINT64:
+		return parseStringToNumberImpl[uint64, OffsetT](func(s string) (uint64, error) {
+			return strconv.ParseUint(s, intParseBase(s), 64)
+		})
+	case arrow.FLOAT32:
+		return parseStringToNumberImpl[float32, OffsetT](func(s string) (float32, error) {
+			v, err := strconv.ParseFloat(s, 32)
+			return float32(v), err
+		})
+	case arrow.FLOAT64:
+		return parseStringToNumberImpl[float64, OffsetT](func(s string) (float64, error) {
+			return strconv.ParseFloat(s, 64)
+		})
+	}
+	panic("invalid type for getParseStringExec")
+}
+
 func addCommonNumberCasts[T numeric](outTy arrow.DataType, kernels []exec.ScalarKernel) []exec.ScalarKernel {
 	kernels = append(kernels, GetCommonCastKernels(outTy.ID(), exec.NewOutputType(outTy))...)
 
@@ -668,7 +759,16 @@ func addCommonNumberCasts[T numeric](outTy arrow.DataType, kernels []exec.Scalar
 		[]exec.InputType{exec.NewExactInput(arrow.FixedWidthTypes.Boolean)}, exec.NewOutputType(outTy),
 		ScalarUnaryBoolArg(boolToNum[T]), nil))
 
-	// generatevarbinarybase
+	for _, inTy := range []arrow.DataType{arrow.BinaryTypes.Binary, arrow.BinaryTypes.String} {
+		kernels = append(kernels, exec.NewScalarKernel(
+			[]exec.InputType{exec.NewExactInput(inTy)}, exec.NewOutputType(outTy),
+			getParseStringExec[int32](outTy.ID()), nil))
+	}
+	for _, inTy := range []arrow.DataType{arrow.BinaryTypes.LargeBinary, arrow.BinaryTypes.LargeString} {
+		kernels = append(kernels, exec.NewScalarKernel(
+			[]exec.InputType{exec.NewExactInput(inTy)}, exec.NewOutputType(outTy),
+			getParseStringExec[int64](outTy.ID()), nil))
+	}
 	return kernels
 }
 
diff --git a/go/arrow/scalar/scalar.go b/go/arrow/scalar/scalar.go
index a35eb519bad5f..443176d4b7ce4 100644
--- a/go/arrow/scalar/scalar.go
+++ b/go/arrow/scalar/scalar.go
@@ -638,6 +638,8 @@ func GetScalar(arr arrow.Array, idx int) (Scalar, error) {
 		return ScalarNull, nil
 	case *array.String:
 		return NewStringScalar(arr.Value(idx)), nil
+	case *array.LargeString:
+		return NewLargeStringScalar(arr.Value(idx)), nil
 	case *array.Struct:
 		children := make(Vector, arr.NumField())
 		for i := range children {