Skip to content

Commit

Permalink
GH-40693: [Go] Fix Decimal type precision loss on GetOneForMarshal (#40694)
Browse files Browse the repository at this point in the history

### Rationale for this change

`GetOneForMarshal` loses precision when called on `Decimal128` and `Decimal256` arrays.

### What changes are included in this PR?

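Fixes the precision loss in `GetOneForMarshal` for `Decimal128` and `Decimal256`. The value is now scaled with a `big.Float` whose precision is set explicitly (128 or 256 bits), using the multiplier returned by `GetScaleMultiplier` instead of `math.Pow10`. Negative scales are handled by multiplying rather than dividing.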

* GitHub Issue: #40693

Lead-authored-by: Herman Schaaf <hermanschaaf@gmail.com>
Co-authored-by: Kemal Hadimli <disq@users.noreply.github.com>
Signed-off-by: Matt Topol <zotthewizard@gmail.com>
  • Loading branch information
disq and disq committed Mar 25, 2024
1 parent cc771a0 commit 1781b32
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 12 deletions.
13 changes: 8 additions & 5 deletions go/arrow/array/decimal128.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package array
import (
"bytes"
"fmt"
"math"
"math/big"
"reflect"
"strings"
Expand Down Expand Up @@ -86,15 +85,19 @@ func (a *Decimal128) setData(data *Data) {
a.values = a.values[beg:end]
}
}

// GetOneForMarshal returns the element at index i in a form suitable for
// marshalling: nil when the slot is null, otherwise the decimal value rendered
// as a string by big.Float.Text with the 'g' format and the type's Precision
// as the digit count.
//
// NOTE(review): this listing is a rendered diff without +/- markers, so the
// pre-change statements and their replacement both appear in the body below.
func (a *Decimal128) GetOneForMarshal(i int) interface{} {
if a.IsNull(i) {
return nil
}

typ := a.DataType().(*arrow.Decimal128Type)
// Pre-change implementation (presumably the lines this commit removes —
// confirm against the diff): the divisor is built from math.Pow10, a float64,
// and no precision is set on f before dividing.
f := (&big.Float{}).SetInt(a.Value(i).BigInt())
f.Quo(f, big.NewFloat(math.Pow10(int(typ.Scale))))
// Replacement implementation: the scale factor comes from
// decimal128.GetScaleMultiplier as a big integer, and f is given 128 bits of
// precision before the multiply/divide so the result is rounded to 128 bits.
n := a.Value(i)
scale := typ.Scale
f := (&big.Float{}).SetInt(n.BigInt())
if scale < 0 {
// Negative scale: multiply by the multiplier for -scale.
f.SetPrec(128).Mul(f, (&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(-scale)).BigInt()))
} else {
// Non-negative scale: divide by the multiplier for scale.
f.SetPrec(128).Quo(f, (&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(scale)).BigInt()))
}
return f.Text('g', int(typ.Precision))
}

Expand Down
59 changes: 58 additions & 1 deletion go/arrow/array/decimal128_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,17 @@ func TestDecimal128StringRoundTrip(t *testing.T) {
decimal128.FromI64(9),
decimal128.FromI64(10),
}
valid := []bool{true, true, true, false, true, true, false, true, true, true}
val1, err := decimal128.FromString("0.99", dt.Precision, dt.Scale)
if err != nil {
t.Fatal(err)
}
val2, err := decimal128.FromString("1234567890.12345", dt.Precision, dt.Scale)
if err != nil {
t.Fatal(err)
}
values = append(values, val1, val2)

valid := []bool{true, true, true, false, true, true, false, true, true, true, true, true}

b.AppendValues(values, valid)

Expand All @@ -224,3 +234,50 @@ func TestDecimal128StringRoundTrip(t *testing.T) {

assert.True(t, array.Equal(arr, arr1))
}

// TestDecimal128GetOneForMarshal builds a Decimal128 array (precision 38,
// scale 20) from decimal strings plus one null entry, and asserts that
// GetOneForMarshal returns the expected string for every slot and nil for
// the null slot.
func TestDecimal128GetOneForMarshal(t *testing.T) {
	mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
	defer mem.AssertSize(t, 0)

	dtype := &arrow.Decimal128Type{Precision: 38, Scale: 20}

	bldr := array.NewDecimal128Builder(mem, dtype)
	defer bldr.Release()

	// A nil give marks a null slot; every other give is a decimal string.
	type marshalCase struct {
		give any
		want any
	}
	cases := []marshalCase{
		{give: "1", want: "1"},
		{give: "1.25", want: "1.25"},
		{give: "0.99", want: "0.99"},
		{give: "1234567890.123456789", want: "1234567890.123456789"},
		{give: nil, want: nil},
		{give: "-0.99", want: "-0.99"},
		{give: "-1234567890.123456789", want: "-1234567890.123456789"},
		{give: "0.0000000000000000001", want: "1e-19"},
	}

	for idx := range cases {
		input := cases[idx].give
		if input == nil {
			bldr.AppendNull()
			continue
		}

		num, err := decimal128.FromString(input.(string), dtype.Precision, dtype.Scale)
		if err != nil {
			t.Fatal(err)
		}
		bldr.Append(num)
	}

	arr := bldr.NewDecimal128Array()
	defer arr.Release()

	gotLen, wantLen := arr.Len(), len(cases)
	if gotLen != wantLen {
		t.Fatalf("invalid array length: got=%d, want=%d", gotLen, wantLen)
	}

	for idx, tc := range cases {
		assert.Equalf(t, tc.want, arr.GetOneForMarshal(idx), "unexpected value at index %d", idx)
	}
}
12 changes: 8 additions & 4 deletions go/arrow/array/decimal256.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package array
import (
"bytes"
"fmt"
"math"
"math/big"
"reflect"
"strings"
Expand Down Expand Up @@ -91,10 +90,15 @@ func (a *Decimal256) GetOneForMarshal(i int) interface{} {
if a.IsNull(i) {
return nil
}

typ := a.DataType().(*arrow.Decimal256Type)
f := (&big.Float{}).SetInt(a.Value(i).BigInt())
f.Quo(f, big.NewFloat(math.Pow10(int(typ.Scale))))
n := a.Value(i)
scale := typ.Scale
f := (&big.Float{}).SetInt(n.BigInt())
if scale < 0 {
f.SetPrec(256).Mul(f, (&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(-scale)).BigInt()))
} else {
f.SetPrec(256).Quo(f, (&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(scale)).BigInt()))
}
return f.Text('g', int(typ.Precision))
}

Expand Down
70 changes: 68 additions & 2 deletions go/arrow/array/decimal256_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,17 @@ func TestDecimal256StringRoundTrip(t *testing.T) {
decimal256.FromI64(9),
decimal256.FromI64(10),
}
valid := []bool{true, true, true, false, true, true, false, true, true, true}
val1, err := decimal256.FromString("0.99", dt.Precision, dt.Scale)
if err != nil {
t.Fatal(err)
}
val2, err := decimal256.FromString("1234567890.123456789", dt.Precision, dt.Scale)
if err != nil {
t.Fatal(err)
}
values = append(values, val1, val2)

valid := []bool{true, true, true, false, true, true, false, true, true, true, true, true}

b.AppendValues(values, valid)

Expand All @@ -217,11 +227,67 @@ func TestDecimal256StringRoundTrip(t *testing.T) {
defer b1.Release()

for i := 0; i < arr.Len(); i++ {
assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i)))
v := arr.ValueStr(i)
assert.NoError(t, b1.AppendValueFromString(v))
}

arr1 := b1.NewArray().(*array.Decimal256)
defer arr1.Release()

for i := 0; i < arr.Len(); i++ {
if arr.IsNull(i) && arr1.IsNull(i) {
continue
}
if arr.Value(i) != arr1.Value(i) {
t.Fatalf("unexpected value at index %d: got=%v, want=%v", i, arr1.Value(i), arr.Value(i))
}
}
assert.True(t, array.Equal(arr, arr1))
}

// TestDecimal256GetOneForMarshal builds a Decimal256 array (precision 38,
// scale 20) from decimal strings plus one null entry, and asserts that
// GetOneForMarshal returns the expected string for every slot and nil for
// the null slot.
func TestDecimal256GetOneForMarshal(t *testing.T) {
	mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
	defer mem.AssertSize(t, 0)

	dtype := &arrow.Decimal256Type{Precision: 38, Scale: 20}

	bldr := array.NewDecimal256Builder(mem, dtype)
	defer bldr.Release()

	// A nil give marks a null slot; every other give is a decimal string.
	type marshalCase struct {
		give any
		want any
	}
	cases := []marshalCase{
		{give: "1", want: "1"},
		{give: "1.25", want: "1.25"},
		{give: "0.99", want: "0.99"},
		{give: "1234567890.123456789", want: "1234567890.123456789"},
		{give: nil, want: nil},
		{give: "-0.99", want: "-0.99"},
		{give: "-1234567890.123456789", want: "-1234567890.123456789"},
		{give: "0.0000000000000000001", want: "1e-19"},
	}

	for idx := range cases {
		input := cases[idx].give
		if input == nil {
			bldr.AppendNull()
			continue
		}

		num, err := decimal256.FromString(input.(string), dtype.Precision, dtype.Scale)
		if err != nil {
			t.Fatal(err)
		}
		bldr.Append(num)
	}

	arr := bldr.NewDecimal256Array()
	defer arr.Release()

	gotLen, wantLen := arr.Len(), len(cases)
	if gotLen != wantLen {
		t.Fatalf("invalid array length: got=%d, want=%d", gotLen, wantLen)
	}

	for idx, tc := range cases {
		assert.Equalf(t, tc.want, arr.GetOneForMarshal(idx), "unexpected value at index %d", idx)
	}
}

0 comments on commit 1781b32

Please sign in to comment.