Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-17586: [Go] String To Numeric cast functions #14015

Merged
merged 2 commits into from
Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions go/arrow/compute/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ func CastArray(ctx context.Context, val arrow.Array, opts *CastOptions) (arrow.A
return out.(*ArrayDatum).MakeArray(), nil
}

// CastToType casts val to toType with the options produced by
// SafeCastOptions(toType). It is shorthand for calling
// CastArray(ctx, val, compute.SafeCastOptions(toType)) directly.
func CastToType(ctx context.Context, val arrow.Array, toType arrow.DataType) (arrow.Array, error) {
	opts := SafeCastOptions(toType)
	return CastArray(ctx, val, opts)
}

// CanCast returns true if there is an implementation for casting an array
// or scalar value from the specified DataType to the other data type.
func CanCast(from, to arrow.DataType) bool {
Expand Down
101 changes: 96 additions & 5 deletions go/arrow/compute/cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,19 @@ func checkCastZeroCopy(t *testing.T, input arrow.Array, toType arrow.DataType, o
}

var (
integerTypes = []arrow.DataType{
arrow.PrimitiveTypes.Uint8,
signedIntTypes = []arrow.DataType{
arrow.PrimitiveTypes.Int8,
arrow.PrimitiveTypes.Uint16,
arrow.PrimitiveTypes.Int16,
arrow.PrimitiveTypes.Uint32,
arrow.PrimitiveTypes.Int32,
arrow.PrimitiveTypes.Uint64,
arrow.PrimitiveTypes.Int64,
}
unsignedIntTypes = []arrow.DataType{
arrow.PrimitiveTypes.Uint8,
arrow.PrimitiveTypes.Uint16,
arrow.PrimitiveTypes.Uint32,
arrow.PrimitiveTypes.Uint64,
}
integerTypes = append(signedIntTypes, unsignedIntTypes...)
numericTypes = append(integerTypes,
arrow.PrimitiveTypes.Float32,
arrow.PrimitiveTypes.Float64)
Expand Down Expand Up @@ -1211,6 +1214,94 @@ func (c *CastSuite) TestDecimalToFloating() {
}
}

// TestStringToInt verifies casting String and LargeString arrays to every
// signed and unsigned integer type, covering decimal and hexadecimal
// literals, nulls, the extreme values of the 32- and 64-bit types, and
// inputs that must be rejected when casting to Int8 / Uint8.
func (c *CastSuite) TestStringToInt() {
	const (
		// values that fit in every signed integer width
		signedIn  = `["0", null, "127", "-1", "0", "0x0", "0x7F"]`
		signedOut = `[0, null, 127, -1, 0, 0, 127]`
		// values that fit in every unsigned integer width
		unsignedIn  = `["0", null, "127", "255", "0", "0x0", "0xff", "0X7f"]`
		unsignedOut = `[0, null, 127, 255, 0, 0, 255, 127]`
	)

	// inputs that are malformed or out of range for the 8-bit targets
	rejectedAsInt8 := []string{"z", "12 z", "128", "-129", "0.5", "0x", "0xfff", "-0xf0"}
	rejectedAsUint8 := []string{"256", "-1", "0.5", "0x", "0x3wa", "0x123"}

	stringTypes := []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString}
	for _, srcType := range stringTypes {
		for _, target := range signedIntTypes {
			c.checkCast(srcType, target, signedIn, signedOut)
		}

		// full range of int32
		c.checkCast(srcType, arrow.PrimitiveTypes.Int32,
			`["2147483647", null, "-2147483648", "0", "0X0", "0x7FFFFFFF", "-0X1", "-0x10000000"]`,
			`[2147483647, null, -2147483648, 0, 0, 2147483647, -1, -268435456]`)

		// full range of int64
		c.checkCast(srcType, arrow.PrimitiveTypes.Int64,
			`["9223372036854775807", null, "-9223372036854775808", "0", "0x0", "0x7FFFFFFFFFFFFFFf", "-0x0FFFFFFFFFFFFFFF"]`,
			`[9223372036854775807, null, -9223372036854775808, 0, 0, 9223372036854775807, -1152921504606846975]`)

		for _, target := range unsignedIntTypes {
			c.checkCast(srcType, target, unsignedIn, unsignedOut)
		}

		// full range of uint32
		c.checkCast(srcType, arrow.PrimitiveTypes.Uint32,
			`["2147483647", null, "4294967295", "0", "0x0", "0x7FFFFFFf", "0xFFFFFFFF"]`,
			`[2147483647, null, 4294967295, 0, 0, 2147483647, 4294967295]`)

		// full range of uint64
		c.checkCast(srcType, arrow.PrimitiveTypes.Uint64,
			`["9223372036854775807", null, "18446744073709551615", "0", "0x0", "0x7FFFFFFFFFFFFFFf", "0xfFFFFFFFFFFFFFFf"]`,
			`[9223372036854775807, null, 18446744073709551615, 0, 0, 9223372036854775807, 18446744073709551615]`)

		for _, bad := range rejectedAsInt8 {
			c.checkCastFails(srcType, `["`+bad+`"]`, compute.SafeCastOptions(arrow.PrimitiveTypes.Int8))
		}

		for _, bad := range rejectedAsUint8 {
			c.checkCastFails(srcType, `["`+bad+`"]`, compute.SafeCastOptions(arrow.PrimitiveTypes.Uint8))
		}
	}
}

// TestStringToFloating verifies casting String and LargeString arrays to
// Float32 and Float64, and that a non-numeric string is rejected under the
// safe cast options.
func (c *CastSuite) TestStringToFloating() {
	const (
		validIn  = `["0.1", null, "127.3", "1e3", "200.4", "0.5"]`
		validOut = `[0.1, null, 127.3, 1000, 200.4, 0.5]`
	)

	stringTypes := []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString}
	floatTypes := []arrow.DataType{arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float64}
	// inputs that must fail to parse as a floating point number
	rejected := []string{"z"}

	for _, srcType := range stringTypes {
		for _, target := range floatTypes {
			c.checkCast(srcType, target, validIn, validOut)

			for _, bad := range rejected {
				c.checkCastFails(srcType, `["`+bad+`"]`, compute.SafeCastOptions(target))
			}
		}
	}
}

// TestUnsupportedInputType checks the error produced when the target type
// of a cast is supported in general but has no kernel for the given input
// type (here int32 -> list<utf8>). The same error is expected from the
// CastToType convenience function and from the generic CallFunction API.
func (c *CastSuite) TestUnsupportedInputType() {
	const wantMsg = "unsupported cast to list<item: utf8, nullable> from int32"

	ctx := context.Background()
	listOfStrings := arrow.ListOf(arrow.BinaryTypes.String)

	input, _, _ := array.FromJSON(c.mem, arrow.PrimitiveTypes.Int32, strings.NewReader(`[1, 2, 3]`))
	defer input.Release()

	// convenience entry point
	_, err := compute.CastToType(ctx, input, listOfStrings)
	c.ErrorIs(err, arrow.ErrNotImplemented)
	c.ErrorContains(err, wantMsg)

	// generic kernel API entry point
	asDatum := compute.NewDatum(input)
	defer asDatum.Release()
	_, err = compute.CallFunction(ctx, "cast", compute.SafeCastOptions(listOfStrings), asDatum)
	c.ErrorIs(err, arrow.ErrNotImplemented)
	c.ErrorContains(err, wantMsg)
}

func (c *CastSuite) TestUnsupportedTargetType() {
arr, _, _ := array.FromJSON(c.mem, arrow.PrimitiveTypes.Int32, strings.NewReader(`[1, 2, 3]`))
defer arr.Release()

toType := arrow.DenseUnionOf([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32}}, []arrow.UnionTypeCode{0})
_, err := compute.CastToType(context.Background(), arr, toType)
c.ErrorIs(err, arrow.ErrNotImplemented)
c.ErrorContains(err, "unsupported cast to dense_union<a: type=int32=0> from int32")

// test calling through the generic kernel API
datum := compute.NewDatum(arr)
defer datum.Release()
Comment on lines +1304 to +1305
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this pattern often. I'm not familiar with the Go Arrow implementation; I'm just curious why we need to release things explicitly. What does Release do?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The primary Arrow docs explain this: https://github.com/apache/arrow/tree/master/go#reference-counting

Essentially, the objects utilize reference counting to know when they can eagerly release buffers and re-use objects / memory (particularly with custom Allocators).

_, err = compute.CallFunction(context.Background(), "cast", compute.SafeCastOptions(toType), datum)
c.ErrorIs(err, arrow.ErrNotImplemented)
c.ErrorContains(err, "unsupported cast to dense_union<a: type=int32=0> from int32")
}

func (c *CastSuite) checkCastZeroCopy(dt arrow.DataType, json string) {
arr, _, _ := array.FromJSON(c.mem, dt, strings.NewReader(json))
defer arr.Release()
Expand Down
4 changes: 4 additions & 0 deletions go/arrow/compute/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ type releasable interface {
}

func (d *ScalarDatum) Release() {
if !d.Value.IsValid() {
return
}

if v, ok := d.Value.(releasable); ok {
v.Release()
}
Expand Down
4 changes: 2 additions & 2 deletions go/arrow/compute/internal/kernels/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func ScalarUnaryNotNullBinaryArgBoolOut[OffsetT int32 | int64](defVal bool, op f
// It implements the handling to iterate the offsets and values calling
// the provided function on each byte slice. The zero value of the OutT
// will be used as the output for elements of the input that are null.
func ScalarUnaryNotNullBinaryArg[OutT exec.FixedWidthTypes, OffsetT int32 | int64](op func(*exec.KernelCtx, []byte) (OutT, error)) exec.ArrayKernelExec {
func ScalarUnaryNotNullBinaryArg[OutT exec.FixedWidthTypes, OffsetT int32 | int64](op func(*exec.KernelCtx, []byte, *error) OutT) exec.ArrayKernelExec {
return func(ctx *exec.KernelCtx, in *exec.ExecSpan, out *exec.ExecResult) error {
var (
arg0 = &in.Values[0].Array
Expand All @@ -139,7 +139,7 @@ func ScalarUnaryNotNullBinaryArg[OutT exec.FixedWidthTypes, OffsetT int32 | int6
bitutils.VisitBitBlocks(bitmap, arg0.Offset, arg0.Len,
func(pos int64) {
v := arg0Data[arg0Offsets[pos]:arg0Offsets[pos+1]]
outData[outPos], err = op(ctx, v)
outData[outPos] = op(ctx, v, &err)
outPos++
}, func() {
outData[outPos] = def
Expand Down
77 changes: 76 additions & 1 deletion go/arrow/compute/internal/kernels/numeric_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package kernels

import (
"fmt"
"strconv"
"unsafe"

"github.com/apache/arrow/go/v10/arrow"
"github.com/apache/arrow/go/v10/arrow/bitutil"
Expand Down Expand Up @@ -661,14 +663,87 @@ func checkIntToFloatTrunc(in *exec.ArraySpan, outType arrow.Type) error {
return nil
}

// parseStringToNumberImpl adapts parseFn into an array kernel exec via
// ScalarUnaryNotNullBinaryArg. For each input element handed to it, the
// raw bytes are viewed as a string and passed to parseFn; when parseFn
// fails, the failure is wrapped in arrow.ErrInvalid and stored through
// the error pointer supplied by the helper. Whatever value parseFn
// returned is written to the output in either case.
func parseStringToNumberImpl[T exec.IntTypes | exec.UintTypes | exec.FloatTypes, OffsetT int32 | int64](parseFn func(string) (T, error)) exec.ArrayKernelExec {
	convert := func(_ *exec.KernelCtx, raw []byte, errOut *error) T {
		// view the byte slice as a string without copying it; the
		// string is only used for the duration of the parseFn call
		text := *(*string)(unsafe.Pointer(&raw))
		parsed, parseErr := parseFn(text)
		if parseErr == nil {
			return parsed
		}
		*errOut = fmt.Errorf("%w: %s", arrow.ErrInvalid, parseErr)
		return parsed
	}
	return ScalarUnaryNotNullBinaryArg[T, OffsetT](convert)
}

// getParseStringExec returns the kernel exec that parses string/binary
// input with offsets of type OffsetT into the numeric type identified by
// out. It panics if out is not one of the integer or floating point type
// ids handled below.
//
// Integers are parsed by strconv.ParseInt / strconv.ParseUint with base 0
// and the bit size of the target type, so the base is taken from the
// literal's prefix ("0x" hex, "0b" binary, "0o" or a bare leading "0"
// octal, otherwise decimal) and values that do not fit the target width
// are reported as errors. Floats are parsed by strconv.ParseFloat with the
// target's bit size.
func getParseStringExec[OffsetT int32 | int64](out arrow.Type) exec.ArrayKernelExec {
	// base 0 asks strconv to infer the radix from the literal itself
	const autoBase = 0

	switch out {
	case arrow.INT8:
		parse := func(text string) (int8, error) {
			n, perr := strconv.ParseInt(text, autoBase, 8)
			return int8(n), perr
		}
		return parseStringToNumberImpl[int8, OffsetT](parse)
	case arrow.UINT8:
		parse := func(text string) (uint8, error) {
			n, perr := strconv.ParseUint(text, autoBase, 8)
			return uint8(n), perr
		}
		return parseStringToNumberImpl[uint8, OffsetT](parse)
	case arrow.INT16:
		parse := func(text string) (int16, error) {
			n, perr := strconv.ParseInt(text, autoBase, 16)
			return int16(n), perr
		}
		return parseStringToNumberImpl[int16, OffsetT](parse)
	case arrow.UINT16:
		parse := func(text string) (uint16, error) {
			n, perr := strconv.ParseUint(text, autoBase, 16)
			return uint16(n), perr
		}
		return parseStringToNumberImpl[uint16, OffsetT](parse)
	case arrow.INT32:
		parse := func(text string) (int32, error) {
			n, perr := strconv.ParseInt(text, autoBase, 32)
			return int32(n), perr
		}
		return parseStringToNumberImpl[int32, OffsetT](parse)
	case arrow.UINT32:
		parse := func(text string) (uint32, error) {
			n, perr := strconv.ParseUint(text, autoBase, 32)
			return uint32(n), perr
		}
		return parseStringToNumberImpl[uint32, OffsetT](parse)
	case arrow.INT64:
		parse := func(text string) (int64, error) {
			return strconv.ParseInt(text, autoBase, 64)
		}
		return parseStringToNumberImpl[int64, OffsetT](parse)
	case arrow.UINT64:
		parse := func(text string) (uint64, error) {
			return strconv.ParseUint(text, autoBase, 64)
		}
		return parseStringToNumberImpl[uint64, OffsetT](parse)
	case arrow.FLOAT32:
		parse := func(text string) (float32, error) {
			f, perr := strconv.ParseFloat(text, 32)
			return float32(f), perr
		}
		return parseStringToNumberImpl[float32, OffsetT](parse)
	case arrow.FLOAT64:
		parse := func(text string) (float64, error) {
			return strconv.ParseFloat(text, 64)
		}
		return parseStringToNumberImpl[float64, OffsetT](parse)
	default:
		panic("invalid type for getParseStringExec")
	}
}

// addCommonNumberCasts appends to kernels the cast kernels shared by every
// numeric output type outTy and returns the extended slice. In order, it
// adds the kernels from GetCommonCastKernels, a boolean -> T kernel, and
// the string-parsing kernels for Binary, String, LargeBinary and
// LargeString inputs.
func addCommonNumberCasts[T numeric](outTy arrow.DataType, kernels []exec.ScalarKernel) []exec.ScalarKernel {
	kernels = append(kernels, GetCommonCastKernels(outTy.ID(), exec.NewOutputType(outTy))...)

	boolInput := []exec.InputType{exec.NewExactInput(arrow.FixedWidthTypes.Boolean)}
	kernels = append(kernels, exec.NewScalarKernel(
		boolInput, exec.NewOutputType(outTy), ScalarUnaryBoolArg(boolToNum[T]), nil))

	// variable-length binary inputs are parsed as text; the regular types
	// are registered with the int32 instantiation of the parser and the
	// Large* types with the int64 instantiation
	parsers := []struct {
		inputs  []arrow.DataType
		newExec func(arrow.Type) exec.ArrayKernelExec
	}{
		{
			inputs:  []arrow.DataType{arrow.BinaryTypes.Binary, arrow.BinaryTypes.String},
			newExec: getParseStringExec[int32],
		},
		{
			inputs:  []arrow.DataType{arrow.BinaryTypes.LargeBinary, arrow.BinaryTypes.LargeString},
			newExec: getParseStringExec[int64],
		},
	}
	for _, group := range parsers {
		for _, inTy := range group.inputs {
			kernels = append(kernels, exec.NewScalarKernel(
				[]exec.InputType{exec.NewExactInput(inTy)}, exec.NewOutputType(outTy),
				group.newExec(outTy.ID()), nil))
		}
	}
	return kernels
}

Expand Down
2 changes: 2 additions & 0 deletions go/arrow/scalar/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,8 @@ func GetScalar(arr arrow.Array, idx int) (Scalar, error) {
return ScalarNull, nil
case *array.String:
return NewStringScalar(arr.Value(idx)), nil
case *array.LargeString:
return NewLargeStringScalar(arr.Value(idx)), nil
case *array.Struct:
children := make(Vector, arr.NumField())
for i := range children {
Expand Down