diff --git a/go/parquet/internal/encoding/boolean_encoder.go b/go/parquet/internal/encoding/boolean_encoder.go index fc9cd2728acda..617eaa9471ec6 100644 --- a/go/parquet/internal/encoding/boolean_encoder.go +++ b/go/parquet/internal/encoding/boolean_encoder.go @@ -74,11 +74,11 @@ func (enc *PlainBooleanEncoder) EstimatedDataEncodedSize() int64 { // FlushValues returns the buffered data, the responsibility is on the caller // to release the buffer memory -func (enc *PlainBooleanEncoder) FlushValues() Buffer { +func (enc *PlainBooleanEncoder) FlushValues() (Buffer, error) { if enc.wr.Pos() > 0 { toFlush := int(enc.wr.Pos()) enc.append(enc.bitsBuffer[:bitutil.BytesForBits(int64(toFlush))]) } - return enc.sink.Finish() + return enc.sink.Finish(), nil } diff --git a/go/parquet/internal/encoding/delta_bit_packing.go b/go/parquet/internal/encoding/delta_bit_packing.go index babd0b1fa97a2..f5a3867208be9 100644 --- a/go/parquet/internal/encoding/delta_bit_packing.go +++ b/go/parquet/internal/encoding/delta_bit_packing.go @@ -428,7 +428,8 @@ func (enc *deltaBitPackEncoder) putInternal(data interface{}) { } // FlushValues flushes any remaining data and returns the finished encoded buffer -func (enc *deltaBitPackEncoder) FlushValues() Buffer { +// or returns nil and any error encountered during flushing. +func (enc *deltaBitPackEncoder) FlushValues() (Buffer, error) { if enc.bitWriter != nil { // write any remaining values enc.flushBlock() @@ -457,7 +458,7 @@ func (enc *deltaBitPackEncoder) FlushValues() Buffer { buffer = append(buffer, flushed.Buf()[:enc.bitWriter.Written()]...) } - return poolBuffer{memory.NewBufferBytes(buffer)} + return poolBuffer{memory.NewBufferBytes(buffer)}, nil } // EstimatedDataEncodedSize returns the current amount of data actually flushed out and written diff --git a/go/parquet/internal/encoding/delta_byte_array.go b/go/parquet/internal/encoding/delta_byte_array.go index d11413ea236c7..c97e8914dfe3c 100644 --- a/go/parquet/internal/encoding/delta_byte_array.go +++ b/go/parquet/internal/encoding/delta_byte_array.go @@ -105,21 +105,28 @@ func (enc *DeltaByteArrayEncoder) PutSpaced(in []parquet.ByteArray, validBits [] } // Flush flushes any remaining data out and returns the finished encoded buffer. -func (enc *DeltaByteArrayEncoder) FlushValues() Buffer { +// It returns nil and any error encountered during flushing. +func (enc *DeltaByteArrayEncoder) FlushValues() (Buffer, error) { if enc.prefixEncoder == nil { enc.initEncoders() } - prefixBuf := enc.prefixEncoder.FlushValues() + prefixBuf, err := enc.prefixEncoder.FlushValues() + if err != nil { + return nil, err + } defer prefixBuf.Release() - suffixBuf := enc.suffixEncoder.FlushValues() + suffixBuf, err := enc.suffixEncoder.FlushValues() + if err != nil { + return nil, err + } defer suffixBuf.Release() ret := bufferPool.Get().(*memory.Buffer) ret.ResizeNoShrink(prefixBuf.Len() + suffixBuf.Len()) copy(ret.Bytes(), prefixBuf.Bytes()) copy(ret.Bytes()[prefixBuf.Len():], suffixBuf.Bytes()) - return poolBuffer{ret} + return poolBuffer{ret}, nil } // DeltaByteArrayDecoder is a decoder for a column of data encoded using incremental or prefix encoding.
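Note on the signature change above: FlushValues now returns (Buffer, error) instead of a bare Buffer, so callers must check the error before touching the buffer, and they still own the result and must Release it when done (as the encoder doc comment states). A minimal caller-side sketch under those assumptions — the helper name flushEncoded is hypothetical, the constructor and type-assertion pattern mirrors the benchmarks added later in this diff, and like the new _test files it would have to live inside the module since the encoding package is internal:

    package encoding_test

    import (
        "github.com/apache/arrow/go/arrow/memory"
        "github.com/apache/arrow/go/parquet"
        "github.com/apache/arrow/go/parquet/internal/encoding"
    )

    // flushEncoded plain-encodes vals and copies the encoded bytes out so the
    // pooled buffer can be released before returning.
    func flushEncoded(vals []int32) ([]byte, error) {
        enc := encoding.NewEncoder(parquet.Types.Int32, parquet.Encodings.Plain,
            false, nil, memory.DefaultAllocator).(encoding.Int32Encoder)
        enc.Put(vals)
        buf, err := enc.FlushValues() // new (Buffer, error) form
        if err != nil {
            return nil, err
        }
        defer buf.Release()
        out := make([]byte, buf.Len())
        copy(out, buf.Bytes())
        return out, nil
    }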
diff --git a/go/parquet/internal/encoding/delta_length_byte_array.go b/go/parquet/internal/encoding/delta_length_byte_array.go index 3563ccec4617c..8a8dd00ae12e7 100644 --- a/go/parquet/internal/encoding/delta_length_byte_array.go +++ b/go/parquet/internal/encoding/delta_length_byte_array.go @@ -71,9 +71,13 @@ func (DeltaLengthByteArrayEncoder) Type() parquet.Type { return parquet.Types.ByteArray } -// FlushValues flushes any remaining data and returns the final encoded buffer of data. -func (enc *DeltaLengthByteArrayEncoder) FlushValues() Buffer { - ret := enc.lengthEncoder.FlushValues() +// FlushValues flushes any remaining data and returns the final encoded buffer of data +// or returns nil and any error encountered. +func (enc *DeltaLengthByteArrayEncoder) FlushValues() (Buffer, error) { + ret, err := enc.lengthEncoder.FlushValues() + if err != nil { + return nil, err + } defer ret.Release() data := enc.sink.Finish() @@ -83,7 +87,7 @@ func (enc *DeltaLengthByteArrayEncoder) FlushValues() Buffer { output.ResizeNoShrink(ret.Len() + data.Len()) copy(output.Bytes(), ret.Bytes()) copy(output.Bytes()[ret.Len():], data.Bytes()) - return poolBuffer{output} + return poolBuffer{output}, nil } // DeltaLengthByteArrayDecoder is a decoder for handling data produced by the corresponding diff --git a/go/parquet/internal/encoding/encoder.go b/go/parquet/internal/encoding/encoder.go index 49072c8e1519e..2f9a40cc0d9a0 100644 --- a/go/parquet/internal/encoding/encoder.go +++ b/go/parquet/internal/encoding/encoder.go @@ -99,7 +99,7 @@ func (e *encoder) append(data []byte) { e.sink.Write(data) } // FlushValues flushes any unwritten data to the buffer and returns the finished encoded buffer of data. // This also clears the encoder, ownership of the data belongs to whomever called FlushValues, Release // should be called on the resulting Buffer when done. -func (e *encoder) FlushValues() Buffer { return e.sink.Finish() } +func (e *encoder) FlushValues() (Buffer, error) { return e.sink.Finish(), nil } // Bytes returns the current bytes that have been written to the encoder's buffer but doesn't transfer ownership. func (e *encoder) Bytes() []byte { return e.sink.Bytes() } @@ -147,13 +147,17 @@ func (d *dictEncoder) addIndex(idx int) { } // FlushValues dumps all the currently buffered indexes that would become the data page to a buffer and -// returns it. -func (d *dictEncoder) FlushValues() Buffer { +// returns it or returns nil and any error encountered. +func (d *dictEncoder) FlushValues() (Buffer, error) { buf := bufferPool.Get().(*memory.Buffer) buf.Reserve(int(d.EstimatedDataEncodedSize())) - size := d.WriteIndices(buf.Buf()) + size, err := d.WriteIndices(buf.Buf()) + if err != nil { + poolBuffer{buf}.Release() + return nil, err + } buf.ResizeNoShrink(size) - return poolBuffer{buf} + return poolBuffer{buf}, nil } // EstimatedDataEncodedSize returns the maximum number of bytes needed to store the RLE encoded indexes, not including the @@ -187,19 +191,20 @@ func (d *dictEncoder) WriteDict(out []byte) { // WriteIndices performs Run Length encoding on the indexes and the writes the encoded // index value data to the provided byte slice, returning the number of bytes actually written. -func (d *dictEncoder) WriteIndices(out []byte) int { +// If any error is encountered, it will return -1 and the error. 
+func (d *dictEncoder) WriteIndices(out []byte) (int, error) { out[0] = byte(d.BitWidth()) enc := utils.NewRleEncoder(utils.NewWriterAtBuffer(out[1:]), d.BitWidth()) for _, idx := range d.idxValues { - if !enc.Put(uint64(idx)) { - return -1 + if err := enc.Put(uint64(idx)); err != nil { + return -1, err } } nbytes := enc.Flush() d.idxValues = d.idxValues[:0] - return nbytes + 1 + return nbytes + 1, nil } // Put adds a value to the dictionary data column, inserting the value if it diff --git a/go/parquet/internal/encoding/encoding_benchmarks_test.go b/go/parquet/internal/encoding/encoding_benchmarks_test.go new file mode 100644 index 0000000000000..f13d3d02187e3 --- /dev/null +++ b/go/parquet/internal/encoding/encoding_benchmarks_test.go @@ -0,0 +1,466 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package encoding_test + +import ( + "fmt" + "math" + "testing" + + "github.com/apache/arrow/go/arrow" + "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/memory" + "github.com/apache/arrow/go/parquet" + "github.com/apache/arrow/go/parquet/internal/encoding" + "github.com/apache/arrow/go/parquet/internal/hashing" + "github.com/apache/arrow/go/parquet/internal/testutils" + "github.com/apache/arrow/go/parquet/schema" +) + +const ( + MINSIZE = 1024 + MAXSIZE = 65536 +) + +func BenchmarkPlainEncodingBoolean(b *testing.B) { + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + values := make([]bool, sz) + for idx := range values { + values[idx] = true + } + encoder := encoding.NewEncoder(parquet.Types.Boolean, parquet.Encodings.Plain, + false, nil, memory.DefaultAllocator).(encoding.BooleanEncoder) + b.ResetTimer() + b.SetBytes(int64(len(values))) + for n := 0; n < b.N; n++ { + encoder.Put(values) + buf, _ := encoder.FlushValues() + buf.Release() + } + }) + } +} + +func BenchmarkPlainEncodingInt32(b *testing.B) { + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + values := make([]int32, sz) + for idx := range values { + values[idx] = 64 + } + encoder := encoding.NewEncoder(parquet.Types.Int32, parquet.Encodings.Plain, + false, nil, memory.DefaultAllocator).(encoding.Int32Encoder) + b.ResetTimer() + b.SetBytes(int64(len(values) * arrow.Int32SizeBytes)) + for n := 0; n < b.N; n++ { + encoder.Put(values) + buf, _ := encoder.FlushValues() + buf.Release() + } + }) + } +} + +func BenchmarkPlainEncodingInt64(b *testing.B) { + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + values := make([]int64, sz) + for idx := range values { + values[idx] = 64 + } + encoder := encoding.NewEncoder(parquet.Types.Int64, parquet.Encodings.Plain, + false, nil, 
memory.DefaultAllocator).(encoding.Int64Encoder) + b.ResetTimer() + b.SetBytes(int64(len(values) * arrow.Int64SizeBytes)) + for n := 0; n < b.N; n++ { + encoder.Put(values) + buf, _ := encoder.FlushValues() + buf.Release() + } + }) + } +} + +func BenchmarkPlainEncodingFloat32(b *testing.B) { + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + values := make([]float32, sz) + for idx := range values { + values[idx] = 64.0 + } + encoder := encoding.NewEncoder(parquet.Types.Float, parquet.Encodings.Plain, + false, nil, memory.DefaultAllocator).(encoding.Float32Encoder) + b.ResetTimer() + b.SetBytes(int64(len(values) * arrow.Float32SizeBytes)) + for n := 0; n < b.N; n++ { + encoder.Put(values) + buf, _ := encoder.FlushValues() + buf.Release() + } + }) + } +} + +func BenchmarkPlainEncodingFloat64(b *testing.B) { + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + values := make([]float64, sz) + for idx := range values { + values[idx] = 64 + } + encoder := encoding.NewEncoder(parquet.Types.Double, parquet.Encodings.Plain, + false, nil, memory.DefaultAllocator).(encoding.Float64Encoder) + b.ResetTimer() + b.SetBytes(int64(len(values) * arrow.Float64SizeBytes)) + for n := 0; n < b.N; n++ { + encoder.Put(values) + buf, _ := encoder.FlushValues() + buf.Release() + } + }) + } +} + +func BenchmarkPlainDecodingBoolean(b *testing.B) { + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + output := make([]bool, sz) + values := make([]bool, sz) + for idx := range values { + values[idx] = true + } + encoder := encoding.NewEncoder(parquet.Types.Boolean, parquet.Encodings.Plain, + false, nil, memory.DefaultAllocator).(encoding.BooleanEncoder) + encoder.Put(values) + buf, _ := encoder.FlushValues() + defer buf.Release() + + decoder := encoding.NewDecoder(parquet.Types.Boolean, parquet.Encodings.Plain, nil, memory.DefaultAllocator) + b.ResetTimer() + b.SetBytes(int64(len(values))) + for n := 0; n < b.N; n++ { + decoder.SetData(sz, buf.Bytes()) + decoder.(encoding.BooleanDecoder).Decode(output) + } + }) + } +} + +func BenchmarkPlainDecodingInt32(b *testing.B) { + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + output := make([]int32, sz) + values := make([]int32, sz) + for idx := range values { + values[idx] = 64 + } + encoder := encoding.NewEncoder(parquet.Types.Int32, parquet.Encodings.Plain, + false, nil, memory.DefaultAllocator).(encoding.Int32Encoder) + encoder.Put(values) + buf, _ := encoder.FlushValues() + defer buf.Release() + + decoder := encoding.NewDecoder(parquet.Types.Int32, parquet.Encodings.Plain, nil, memory.DefaultAllocator) + b.ResetTimer() + b.SetBytes(int64(len(values))) + for n := 0; n < b.N; n++ { + decoder.SetData(sz, buf.Bytes()) + decoder.(encoding.Int32Decoder).Decode(output) + } + }) + } +} + +func BenchmarkMemoTableFloat64(b *testing.B) { + tests := []struct { + nunique int32 + nvalues int64 + }{ + {100, 65535}, + {1000, 65535}, + {5000, 65535}, + } + + for _, tt := range tests { + b.Run(fmt.Sprintf("%d unique n %d", tt.nunique, tt.nvalues), func(b *testing.B) { + rag := testutils.NewRandomArrayGenerator(0) + dict := rag.Float64(int64(tt.nunique), 0) + indices := rag.Int32(tt.nvalues, 0, int32(tt.nunique)-1, 0) + + values := make([]float64, tt.nvalues) + for idx := range values { + values[idx] = dict.Value(int(indices.Value(idx))) + } + + b.ResetTimer() + b.Run("go map", 
func(b *testing.B) { + for i := 0; i < b.N; i++ { + tbl := encoding.NewFloat64MemoTable(memory.DefaultAllocator) + for _, v := range values { + tbl.GetOrInsert(v) + } + if tbl.Size() != int(tt.nunique) { + b.Fatal(tbl.Size(), tt.nunique) + } + } + }) + b.ResetTimer() + b.Run("xxh3", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tbl := hashing.NewFloat64MemoTable(0) + for _, v := range values { + tbl.GetOrInsert(v) + } + if tbl.Size() != int(tt.nunique) { + b.Fatal(tbl.Size(), tt.nunique) + } + } + }) + }) + } +} + +func BenchmarkMemoTableInt32(b *testing.B) { + tests := []struct { + nunique int32 + nvalues int64 + }{ + {100, 65535}, + {1000, 65535}, + {5000, 65535}, + } + + for _, tt := range tests { + b.Run(fmt.Sprintf("%d unique n %d", tt.nunique, tt.nvalues), func(b *testing.B) { + rag := testutils.NewRandomArrayGenerator(0) + dict := rag.Int32(int64(tt.nunique), 0, math.MaxInt32-1, 0) + indices := rag.Int32(tt.nvalues, 0, int32(tt.nunique)-1, 0) + + values := make([]int32, tt.nvalues) + for idx := range values { + values[idx] = dict.Value(int(indices.Value(idx))) + } + b.ResetTimer() + b.Run("xxh3", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tbl := hashing.NewInt32MemoTable(0) + for _, v := range values { + tbl.GetOrInsert(v) + } + if tbl.Size() != int(tt.nunique) { + b.Fatal(tbl.Size(), tt.nunique) + } + } + }) + + b.Run("go map", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tbl := encoding.NewInt32MemoTable(memory.DefaultAllocator) + for _, v := range values { + tbl.GetOrInsert(v) + } + if tbl.Size() != int(tt.nunique) { + b.Fatal(tbl.Size(), tt.nunique) + } + } + }) + }) + } +} + +func BenchmarkMemoTable(b *testing.B) { + tests := []struct { + nunique int32 + minLen int32 + maxLen int32 + nvalues int64 + }{ + {100, 32, 32, 65535}, + {100, 8, 32, 65535}, + {1000, 32, 32, 65535}, + {1000, 8, 32, 65535}, + {5000, 32, 32, 65535}, + {5000, 8, 32, 65535}, + } + + for _, tt := range tests { + b.Run(fmt.Sprintf("%d unique len %d-%d n %d", tt.nunique, tt.minLen, tt.maxLen, tt.nvalues), func(b *testing.B) { + + rag := testutils.NewRandomArrayGenerator(0) + dict := rag.ByteArray(int64(tt.nunique), tt.minLen, tt.maxLen, 0).(*array.String) + indices := rag.Int32(tt.nvalues, 0, int32(tt.nunique)-1, 0) + + values := make([]parquet.ByteArray, tt.nvalues) + for idx := range values { + values[idx] = []byte(dict.Value(int(indices.Value(idx)))) + } + + b.ResetTimer() + + b.Run("xxh3", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tbl := hashing.NewBinaryMemoTable(memory.DefaultAllocator, 0, -1) + for _, v := range values { + tbl.GetOrInsert(v) + } + if tbl.Size() != int(tt.nunique) { + b.Fatal(tbl.Size(), tt.nunique) + } + tbl.Release() + } + }) + b.ResetTimer() + b.Run("go map", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tbl := encoding.NewBinaryMemoTable(memory.DefaultAllocator) + for _, v := range values { + tbl.GetOrInsert(v) + } + if tbl.Size() != int(tt.nunique) { + b.Fatal(tbl.Size(), tt.nunique) + } + tbl.Release() + } + }) + }) + } +} + +func BenchmarkMemoTableAllUnique(b *testing.B) { + tests := []struct { + minLen int32 + maxLen int32 + nvalues int64 + }{ + {32, 32, 1024}, + {8, 32, 1024}, + {32, 32, 32767}, + {8, 32, 32767}, + {32, 32, 65535}, + {8, 32, 65535}, + } + for _, tt := range tests { + b.Run(fmt.Sprintf("values %d len %d-%d", tt.nvalues, tt.minLen, tt.maxLen), func(b *testing.B) { + + rag := testutils.NewRandomArrayGenerator(0) + dict := rag.ByteArray(tt.nvalues, tt.minLen, tt.maxLen, 0).(*array.String) + + values := make([]parquet.ByteArray, 
tt.nvalues) + for idx := range values { + values[idx] = []byte(dict.Value(idx)) + } + + b.ResetTimer() + b.Run("go map", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tbl := encoding.NewBinaryMemoTable(memory.DefaultAllocator) + for _, v := range values { + tbl.GetOrInsert(v) + } + if tbl.Size() != int(tt.nvalues) { + b.Fatal(tbl.Size(), tt.nvalues) + } + tbl.Release() + } + }) + + b.Run("xxh3", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tbl := hashing.NewBinaryMemoTable(memory.DefaultAllocator, 0, -1) + for _, v := range values { + tbl.GetOrInsert(v) + } + if tbl.Size() != int(tt.nvalues) { + b.Fatal(tbl.Size(), tt.nvalues) + } + tbl.Release() + } + }) + }) + } + +} + +func BenchmarkEncodeDictByteArray(b *testing.B) { + const ( + nunique = 100 + minLen = 8 + maxLen = 32 + nvalues = 65535 + ) + + rag := testutils.NewRandomArrayGenerator(0) + dict := rag.ByteArray(nunique, minLen, maxLen, 0).(*array.String) + indices := rag.Int32(nvalues, 0, nunique-1, 0) + + values := make([]parquet.ByteArray, nvalues) + for idx := range values { + values[idx] = []byte(dict.Value(int(indices.Value(idx)))) + } + col := schema.NewColumn(schema.NewByteArrayNode("bytearray", parquet.Repetitions.Required, -1), 0, 0) + + out := make([]byte, nunique*(maxLen+arrow.Uint32SizeBytes)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + enc := encoding.NewEncoder(parquet.Types.ByteArray, parquet.Encodings.PlainDict, true, col, memory.DefaultAllocator).(*encoding.DictByteArrayEncoder) + enc.Put(values) + enc.WriteDict(out) + } +} + +func BenchmarkDecodeDictByteArray(b *testing.B) { + const ( + nunique = 100 + minLen = 32 + maxLen = 32 + nvalues = 65535 + ) + + rag := testutils.NewRandomArrayGenerator(0) + dict := rag.ByteArray(nunique, minLen, maxLen, 0).(*array.String) + indices := rag.Int32(nvalues, 0, nunique-1, 0) + + values := make([]parquet.ByteArray, nvalues) + for idx := range values { + values[idx] = []byte(dict.Value(int(indices.Value(idx)))) + } + + col := schema.NewColumn(schema.NewByteArrayNode("bytearray", parquet.Repetitions.Required, -1), 0, 0) + enc := encoding.NewEncoder(parquet.Types.ByteArray, parquet.Encodings.PlainDict, true, col, memory.DefaultAllocator).(*encoding.DictByteArrayEncoder) + enc.Put(values) + + dictBuf := make([]byte, enc.DictEncodedSize()) + enc.WriteDict(dictBuf) + + idxBuf := make([]byte, enc.EstimatedDataEncodedSize()) + enc.WriteIndices(idxBuf) + + out := make([]parquet.ByteArray, nvalues) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + dec := encoding.NewDecoder(parquet.Types.ByteArray, parquet.Encodings.Plain, col, memory.DefaultAllocator) + dec.SetData(nunique, dictBuf) + dictDec := encoding.NewDictDecoder(parquet.Types.ByteArray, col, memory.DefaultAllocator).(*encoding.DictByteArrayDecoder) + dictDec.SetDict(dec) + dictDec.SetData(nvalues, idxBuf) + + dictDec.Decode(out) + } +} diff --git a/go/parquet/internal/encoding/encoding_test.go b/go/parquet/internal/encoding/encoding_test.go new file mode 100644 index 0000000000000..b58a13c199184 --- /dev/null +++ b/go/parquet/internal/encoding/encoding_test.go @@ -0,0 +1,684 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package encoding_test + +import ( + "fmt" + "reflect" + "testing" + "unsafe" + + "github.com/apache/arrow/go/arrow" + "github.com/apache/arrow/go/arrow/bitutil" + "github.com/apache/arrow/go/arrow/memory" + "github.com/apache/arrow/go/parquet" + "github.com/apache/arrow/go/parquet/internal/encoding" + "github.com/apache/arrow/go/parquet/internal/testutils" + "github.com/apache/arrow/go/parquet/schema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type nodeFactory func(string, parquet.Repetition, int32) *schema.PrimitiveNode + +func createNodeFactory(t reflect.Type) nodeFactory { + switch t { + case reflect.TypeOf(true): + return schema.NewBooleanNode + case reflect.TypeOf(int32(0)): + return schema.NewInt32Node + case reflect.TypeOf(int64(0)): + return schema.NewInt64Node + case reflect.TypeOf(parquet.Int96{}): + return schema.NewInt96Node + case reflect.TypeOf(float32(0)): + return schema.NewFloat32Node + case reflect.TypeOf(float64(0)): + return schema.NewFloat64Node + case reflect.TypeOf(parquet.ByteArray{}): + return schema.NewByteArrayNode + case reflect.TypeOf(parquet.FixedLenByteArray{}): + return func(name string, rep parquet.Repetition, field int32) *schema.PrimitiveNode { + return schema.NewFixedLenByteArrayNode(name, rep, 12, field) + } + } + return nil +} + +func initdata(t reflect.Type, drawbuf, decodebuf []byte, nvals, repeats int, heap *memory.Buffer) (interface{}, interface{}) { + switch t { + case reflect.TypeOf(true): + draws := *(*[]bool)(unsafe.Pointer(&drawbuf)) + decode := *(*[]bool)(unsafe.Pointer(&decodebuf)) + testutils.InitValues(draws[:nvals], heap) + + for j := 1; j < repeats; j++ { + for k := 0; k < nvals; k++ { + draws[nvals*j+k] = draws[k] + } + } + + return draws[:nvals*repeats], decode[:nvals*repeats] + case reflect.TypeOf(int32(0)): + draws := arrow.Int32Traits.CastFromBytes(drawbuf) + decode := arrow.Int32Traits.CastFromBytes(decodebuf) + testutils.InitValues(draws[:nvals], heap) + + for j := 1; j < repeats; j++ { + for k := 0; k < nvals; k++ { + draws[nvals*j+k] = draws[k] + } + } + + return draws[:nvals*repeats], decode[:nvals*repeats] + case reflect.TypeOf(int64(0)): + draws := arrow.Int64Traits.CastFromBytes(drawbuf) + decode := arrow.Int64Traits.CastFromBytes(decodebuf) + testutils.InitValues(draws[:nvals], heap) + + for j := 1; j < repeats; j++ { + for k := 0; k < nvals; k++ { + draws[nvals*j+k] = draws[k] + } + } + + return draws[:nvals*repeats], decode[:nvals*repeats] + case reflect.TypeOf(parquet.Int96{}): + draws := parquet.Int96Traits.CastFromBytes(drawbuf) + decode := parquet.Int96Traits.CastFromBytes(decodebuf) + testutils.InitValues(draws[:nvals], heap) + + for j := 1; j < repeats; j++ { + for k := 0; k < nvals; k++ { + draws[nvals*j+k] = draws[k] + } + } + + return draws[:nvals*repeats], decode[:nvals*repeats] + case reflect.TypeOf(float32(0)): + draws := arrow.Float32Traits.CastFromBytes(drawbuf) + decode := arrow.Float32Traits.CastFromBytes(decodebuf) + testutils.InitValues(draws[:nvals], heap) + + for j := 1; j < repeats; j++ { + for k := 0; k < nvals; k++ { + draws[nvals*j+k] = 
draws[k] + } + } + + return draws[:nvals*repeats], decode[:nvals*repeats] + case reflect.TypeOf(float64(0)): + draws := arrow.Float64Traits.CastFromBytes(drawbuf) + decode := arrow.Float64Traits.CastFromBytes(decodebuf) + testutils.InitValues(draws[:nvals], heap) + + for j := 1; j < repeats; j++ { + for k := 0; k < nvals; k++ { + draws[nvals*j+k] = draws[k] + } + } + + return draws[:nvals*repeats], decode[:nvals*repeats] + case reflect.TypeOf(parquet.ByteArray{}): + draws := make([]parquet.ByteArray, nvals*repeats) + decode := make([]parquet.ByteArray, nvals*repeats) + testutils.InitValues(draws[:nvals], heap) + + for j := 1; j < repeats; j++ { + for k := 0; k < nvals; k++ { + draws[nvals*j+k] = draws[k] + } + } + + return draws[:nvals*repeats], decode[:nvals*repeats] + case reflect.TypeOf(parquet.FixedLenByteArray{}): + draws := make([]parquet.FixedLenByteArray, nvals*repeats) + decode := make([]parquet.FixedLenByteArray, nvals*repeats) + testutils.InitValues(draws[:nvals], heap) + + for j := 1; j < repeats; j++ { + for k := 0; k < nvals; k++ { + draws[nvals*j+k] = draws[k] + } + } + + return draws[:nvals*repeats], decode[:nvals*repeats] + } + return nil, nil +} + +func encode(enc encoding.TypedEncoder, vals interface{}) { + switch v := vals.(type) { + case []bool: + enc.(encoding.BooleanEncoder).Put(v) + case []int32: + enc.(encoding.Int32Encoder).Put(v) + case []int64: + enc.(encoding.Int64Encoder).Put(v) + case []parquet.Int96: + enc.(encoding.Int96Encoder).Put(v) + case []float32: + enc.(encoding.Float32Encoder).Put(v) + case []float64: + enc.(encoding.Float64Encoder).Put(v) + case []parquet.ByteArray: + enc.(encoding.ByteArrayEncoder).Put(v) + case []parquet.FixedLenByteArray: + enc.(encoding.FixedLenByteArrayEncoder).Put(v) + } +} + +func encodeSpaced(enc encoding.TypedEncoder, vals interface{}, validBits []byte, validBitsOffset int64) { + switch v := vals.(type) { + case []bool: + enc.(encoding.BooleanEncoder).PutSpaced(v, validBits, validBitsOffset) + case []int32: + enc.(encoding.Int32Encoder).PutSpaced(v, validBits, validBitsOffset) + case []int64: + enc.(encoding.Int64Encoder).PutSpaced(v, validBits, validBitsOffset) + case []parquet.Int96: + enc.(encoding.Int96Encoder).PutSpaced(v, validBits, validBitsOffset) + case []float32: + enc.(encoding.Float32Encoder).PutSpaced(v, validBits, validBitsOffset) + case []float64: + enc.(encoding.Float64Encoder).PutSpaced(v, validBits, validBitsOffset) + case []parquet.ByteArray: + enc.(encoding.ByteArrayEncoder).PutSpaced(v, validBits, validBitsOffset) + case []parquet.FixedLenByteArray: + enc.(encoding.FixedLenByteArrayEncoder).PutSpaced(v, validBits, validBitsOffset) + } +} + +func decode(dec encoding.TypedDecoder, out interface{}) (int, error) { + switch v := out.(type) { + case []bool: + return dec.(encoding.BooleanDecoder).Decode(v) + case []int32: + return dec.(encoding.Int32Decoder).Decode(v) + case []int64: + return dec.(encoding.Int64Decoder).Decode(v) + case []parquet.Int96: + return dec.(encoding.Int96Decoder).Decode(v) + case []float32: + return dec.(encoding.Float32Decoder).Decode(v) + case []float64: + return dec.(encoding.Float64Decoder).Decode(v) + case []parquet.ByteArray: + return dec.(encoding.ByteArrayDecoder).Decode(v) + case []parquet.FixedLenByteArray: + return dec.(encoding.FixedLenByteArrayDecoder).Decode(v) + } + return 0, nil +} + +func decodeSpaced(dec encoding.TypedDecoder, out interface{}, nullCount int, validBits []byte, validBitsOffset int64) (int, error) { + switch v := out.(type) { + case []bool: + return 
dec.(encoding.BooleanDecoder).DecodeSpaced(v, nullCount, validBits, validBitsOffset) + case []int32: + return dec.(encoding.Int32Decoder).DecodeSpaced(v, nullCount, validBits, validBitsOffset) + case []int64: + return dec.(encoding.Int64Decoder).DecodeSpaced(v, nullCount, validBits, validBitsOffset) + case []parquet.Int96: + return dec.(encoding.Int96Decoder).DecodeSpaced(v, nullCount, validBits, validBitsOffset) + case []float32: + return dec.(encoding.Float32Decoder).DecodeSpaced(v, nullCount, validBits, validBitsOffset) + case []float64: + return dec.(encoding.Float64Decoder).DecodeSpaced(v, nullCount, validBits, validBitsOffset) + case []parquet.ByteArray: + return dec.(encoding.ByteArrayDecoder).DecodeSpaced(v, nullCount, validBits, validBitsOffset) + case []parquet.FixedLenByteArray: + return dec.(encoding.FixedLenByteArrayDecoder).DecodeSpaced(v, nullCount, validBits, validBitsOffset) + } + return 0, nil +} + +type BaseEncodingTestSuite struct { + suite.Suite + + descr *schema.Column + typeLen int + mem memory.Allocator + typ reflect.Type + + nvalues int + heap *memory.Buffer + inputBytes *memory.Buffer + outputBytes *memory.Buffer + nodeFactory nodeFactory + + draws interface{} + decodeBuf interface{} +} + +func (b *BaseEncodingTestSuite) SetupSuite() { + b.mem = memory.DefaultAllocator + b.inputBytes = memory.NewResizableBuffer(b.mem) + b.outputBytes = memory.NewResizableBuffer(b.mem) + b.heap = memory.NewResizableBuffer(b.mem) + b.nodeFactory = createNodeFactory(b.typ) +} + +func (b *BaseEncodingTestSuite) TearDownSuite() { + b.inputBytes.Release() + b.outputBytes.Release() + b.heap.Release() +} + +func (b *BaseEncodingTestSuite) SetupTest() { + b.descr = schema.NewColumn(b.nodeFactory("name", parquet.Repetitions.Optional, -1), 0, 0) + b.typeLen = int(b.descr.TypeLength()) +} + +func (b *BaseEncodingTestSuite) initData(nvalues, repeats int) { + b.nvalues = nvalues * repeats + b.inputBytes.ResizeNoShrink(b.nvalues * int(b.typ.Size())) + b.outputBytes.ResizeNoShrink(b.nvalues * int(b.typ.Size())) + memory.Set(b.inputBytes.Buf(), 0) + memory.Set(b.outputBytes.Buf(), 0) + + b.draws, b.decodeBuf = initdata(b.typ, b.inputBytes.Buf(), b.outputBytes.Buf(), nvalues, repeats, b.heap) +} + +func (b *BaseEncodingTestSuite) encodeTestData(e parquet.Encoding) (encoding.Buffer, error) { + enc := encoding.NewEncoder(testutils.TypeToParquetType(b.typ), e, false, b.descr, memory.DefaultAllocator) + b.Equal(e, enc.Encoding()) + b.Equal(b.descr.PhysicalType(), enc.Type()) + encode(enc, reflect.ValueOf(b.draws).Slice(0, b.nvalues).Interface()) + return enc.FlushValues() +} + +func (b *BaseEncodingTestSuite) decodeTestData(e parquet.Encoding, buf []byte) { + dec := encoding.NewDecoder(testutils.TypeToParquetType(b.typ), e, b.descr, b.mem) + b.Equal(e, dec.Encoding()) + b.Equal(b.descr.PhysicalType(), dec.Type()) + + dec.SetData(b.nvalues, buf) + decoded, _ := decode(dec, b.decodeBuf) + b.Equal(b.nvalues, decoded) + b.Equal(reflect.ValueOf(b.draws).Slice(0, b.nvalues).Interface(), reflect.ValueOf(b.decodeBuf).Slice(0, b.nvalues).Interface()) +} + +func (b *BaseEncodingTestSuite) encodeTestDataSpaced(e parquet.Encoding, validBits []byte, validBitsOffset int64) (encoding.Buffer, error) { + enc := encoding.NewEncoder(testutils.TypeToParquetType(b.typ), e, false, b.descr, memory.DefaultAllocator) + encodeSpaced(enc, reflect.ValueOf(b.draws).Slice(0, b.nvalues).Interface(), validBits, validBitsOffset) + return enc.FlushValues() +} + +func (b *BaseEncodingTestSuite) decodeTestDataSpaced(e parquet.Encoding, 
nullCount int, buf []byte, validBits []byte, validBitsOffset int64) { + dec := encoding.NewDecoder(testutils.TypeToParquetType(b.typ), e, b.descr, b.mem) + dec.SetData(b.nvalues-nullCount, buf) + decoded, _ := decodeSpaced(dec, b.decodeBuf, nullCount, validBits, validBitsOffset) + b.Equal(b.nvalues, decoded) + + drawval := reflect.ValueOf(b.draws) + decodeval := reflect.ValueOf(b.decodeBuf) + for j := 0; j < b.nvalues; j++ { + if bitutil.BitIsSet(validBits, int(validBitsOffset)+j) { + b.Equal(drawval.Index(j).Interface(), decodeval.Index(j).Interface()) + } + } +} + +func (b *BaseEncodingTestSuite) checkRoundTrip(e parquet.Encoding) { + buf, _ := b.encodeTestData(e) + defer buf.Release() + b.decodeTestData(e, buf.Bytes()) +} + +func (b *BaseEncodingTestSuite) checkRoundTripSpaced(e parquet.Encoding, validBits []byte, validBitsOffset int64) { + buf, _ := b.encodeTestDataSpaced(e, validBits, validBitsOffset) + defer buf.Release() + + nullCount := 0 + for i := 0; i < b.nvalues; i++ { + if bitutil.BitIsNotSet(validBits, int(validBitsOffset)+i) { + nullCount++ + } + } + b.decodeTestDataSpaced(e, nullCount, buf.Bytes(), validBits, validBitsOffset) +} + +func (b *BaseEncodingTestSuite) TestBasicRoundTrip() { + b.initData(10000, 1) + b.checkRoundTrip(parquet.Encodings.Plain) +} + +func (b *BaseEncodingTestSuite) TestDeltaEncodingRoundTrip() { + b.initData(10000, 1) + + switch b.typ { + case reflect.TypeOf(int32(0)), reflect.TypeOf(int64(0)): + b.checkRoundTrip(parquet.Encodings.DeltaBinaryPacked) + default: + b.Panics(func() { b.checkRoundTrip(parquet.Encodings.DeltaBinaryPacked) }) + } +} + +func (b *BaseEncodingTestSuite) TestDeltaLengthByteArrayRoundTrip() { + b.initData(10000, 1) + + switch b.typ { + case reflect.TypeOf(parquet.ByteArray{}): + b.checkRoundTrip(parquet.Encodings.DeltaLengthByteArray) + default: + b.Panics(func() { b.checkRoundTrip(parquet.Encodings.DeltaLengthByteArray) }) + } +} + +func (b *BaseEncodingTestSuite) TestDeltaByteArrayRoundTrip() { + b.initData(10000, 1) + + switch b.typ { + case reflect.TypeOf(parquet.ByteArray{}): + b.checkRoundTrip(parquet.Encodings.DeltaByteArray) + default: + b.Panics(func() { b.checkRoundTrip(parquet.Encodings.DeltaByteArray) }) + } +} + +func (b *BaseEncodingTestSuite) TestSpacedRoundTrip() { + exec := func(vals, repeats int, validBitsOffset int64, nullProb float64) { + b.Run(fmt.Sprintf("%d vals %d repeats %d offset %0.3f null", vals, repeats, validBitsOffset, 1-nullProb), func() { + b.initData(vals, repeats) + + size := int64(b.nvalues) + validBitsOffset + r := testutils.NewRandomArrayGenerator(1923) + arr := r.Uint8(size, 0, 100, 1-nullProb) + validBits := arr.NullBitmapBytes() + if validBits != nil { + b.checkRoundTripSpaced(parquet.Encodings.Plain, validBits, validBitsOffset) + switch b.typ { + case reflect.TypeOf(int32(0)), reflect.TypeOf(int64(0)): + b.checkRoundTripSpaced(parquet.Encodings.DeltaBinaryPacked, validBits, validBitsOffset) + case reflect.TypeOf(parquet.ByteArray{}): + b.checkRoundTripSpaced(parquet.Encodings.DeltaLengthByteArray, validBits, validBitsOffset) + b.checkRoundTripSpaced(parquet.Encodings.DeltaByteArray, validBits, validBitsOffset) + } + } + }) + } + + const ( + avx512Size = 64 + simdSize = avx512Size + multiSimdSize = simdSize * 33 + ) + + for _, nullProb := range []float64{0.001, 0.1, 0.5, 0.9, 0.999} { + // Test with both size and offset up to 3 simd blocks + for i := 1; i < simdSize*3; i++ { + exec(i, 1, 0, nullProb) + exec(i, 1, int64(i+1), nullProb) + } + // large block and offset +
exec(multiSimdSize, 1, 0, nullProb) + exec(multiSimdSize+33, 1, 0, nullProb) + exec(multiSimdSize, 1, 33, nullProb) + exec(multiSimdSize+33, 1, 33, nullProb) + } +} + +func TestEncoding(t *testing.T) { + tests := []struct { + name string + typ reflect.Type + }{ + {"Bool", reflect.TypeOf(true)}, + {"Int32", reflect.TypeOf(int32(0))}, + {"Int64", reflect.TypeOf(int64(0))}, + {"Float32", reflect.TypeOf(float32(0))}, + {"Float64", reflect.TypeOf(float64(0))}, + {"Int96", reflect.TypeOf(parquet.Int96{})}, + {"ByteArray", reflect.TypeOf(parquet.ByteArray{})}, + {"FixedLenByteArray", reflect.TypeOf(parquet.FixedLenByteArray{})}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + suite.Run(t, &BaseEncodingTestSuite{typ: tt.typ}) + }) + } +} + +type DictionaryEncodingTestSuite struct { + BaseEncodingTestSuite +} + +func (d *DictionaryEncodingTestSuite) encodeTestDataDict(e parquet.Encoding) (dictBuffer, indices encoding.Buffer, numEntries int) { + enc := encoding.NewEncoder(testutils.TypeToParquetType(d.typ), e, true, d.descr, memory.DefaultAllocator).(encoding.DictEncoder) + + d.Equal(parquet.Encodings.PlainDict, enc.Encoding()) + d.Equal(d.descr.PhysicalType(), enc.Type()) + encode(enc, reflect.ValueOf(d.draws).Slice(0, d.nvalues).Interface()) + dictBuffer = memory.NewResizableBuffer(d.mem) + dictBuffer.Resize(enc.DictEncodedSize()) + enc.WriteDict(dictBuffer.Bytes()) + indices, _ = enc.FlushValues() + numEntries = enc.NumEntries() + return +} + +func (d *DictionaryEncodingTestSuite) encodeTestDataDictSpaced(e parquet.Encoding, validBits []byte, validBitsOffset int64) (dictBuffer, indices encoding.Buffer, numEntries int) { + enc := encoding.NewEncoder(testutils.TypeToParquetType(d.typ), e, true, d.descr, memory.DefaultAllocator).(encoding.DictEncoder) + d.Equal(d.descr.PhysicalType(), enc.Type()) + + encodeSpaced(enc, reflect.ValueOf(d.draws).Slice(0, d.nvalues).Interface(), validBits, validBitsOffset) + dictBuffer = memory.NewResizableBuffer(d.mem) + dictBuffer.Resize(enc.DictEncodedSize()) + enc.WriteDict(dictBuffer.Bytes()) + indices, _ = enc.FlushValues() + numEntries = enc.NumEntries() + return +} + +func (d *DictionaryEncodingTestSuite) checkRoundTrip() { + dictBuffer, indices, numEntries := d.encodeTestDataDict(parquet.Encodings.Plain) + defer dictBuffer.Release() + defer indices.Release() + validBits := make([]byte, int(bitutil.BytesForBits(int64(d.nvalues)))+1) + memory.Set(validBits, 255) + + spacedBuffer, indicesSpaced, _ := d.encodeTestDataDictSpaced(parquet.Encodings.Plain, validBits, 0) + defer spacedBuffer.Release() + defer indicesSpaced.Release() + d.Equal(indices.Bytes(), indicesSpaced.Bytes()) + + dictDecoder := encoding.NewDecoder(testutils.TypeToParquetType(d.typ), parquet.Encodings.Plain, d.descr, d.mem) + d.Equal(d.descr.PhysicalType(), dictDecoder.Type()) + dictDecoder.SetData(numEntries, dictBuffer.Bytes()) + decoder := encoding.NewDictDecoder(testutils.TypeToParquetType(d.typ), d.descr, d.mem) + decoder.SetDict(dictDecoder) + decoder.SetData(d.nvalues, indices.Bytes()) + + decoded, _ := decode(decoder, d.decodeBuf) + d.Equal(d.nvalues, decoded) + d.Equal(reflect.ValueOf(d.draws).Slice(0, d.nvalues).Interface(), reflect.ValueOf(d.decodeBuf).Slice(0, d.nvalues).Interface()) + + decoder.SetData(d.nvalues, indices.Bytes()) + decoded, _ = decodeSpaced(decoder, d.decodeBuf, 0, validBits, 0) + d.Equal(d.nvalues, decoded) + d.Equal(reflect.ValueOf(d.draws).Slice(0, d.nvalues).Interface(), reflect.ValueOf(d.decodeBuf).Slice(0, d.nvalues).Interface()) +} + 
+func (d *DictionaryEncodingTestSuite) TestBasicRoundTrip() { + d.initData(2500, 2) + d.checkRoundTrip() +} + +func TestDictEncoding(t *testing.T) { + tests := []struct { + name string + typ reflect.Type + }{ + {"Int32", reflect.TypeOf(int32(0))}, + {"Int64", reflect.TypeOf(int64(0))}, + {"Float32", reflect.TypeOf(float32(0))}, + {"Float64", reflect.TypeOf(float64(0))}, + {"ByteArray", reflect.TypeOf(parquet.ByteArray{})}, + {"FixedLenByteArray", reflect.TypeOf(parquet.FixedLenByteArray{})}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + suite.Run(t, &DictionaryEncodingTestSuite{BaseEncodingTestSuite{typ: tt.typ}}) + }) + } +} + +func TestWriteDeltaBitPackedInt32(t *testing.T) { + column := schema.NewColumn(schema.NewInt32Node("int32", parquet.Repetitions.Required, -1), 0, 0) + + tests := []struct { + name string + toencode []int32 + expected []byte + }{ + {"simple 12345", []int32{1, 2, 3, 4, 5}, []byte{128, 1, 4, 5, 2, 2, 0, 0, 0, 0}}, + {"odd vals", []int32{7, 5, 3, 1, 2, 3, 4, 5}, []byte{128, 1, 4, 8, 14, 3, 2, 0, 0, 0, 192, 63, 0, 0, 0, 0, 0, 0}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + enc := encoding.NewEncoder(parquet.Types.Int32, parquet.Encodings.DeltaBinaryPacked, false, column, memory.DefaultAllocator) + + enc.(encoding.Int32Encoder).Put(tt.toencode) + buf, _ := enc.FlushValues() + defer buf.Release() + + assert.Equal(t, tt.expected, buf.Bytes()) + + dec := encoding.NewDecoder(parquet.Types.Int32, parquet.Encodings.DeltaBinaryPacked, column, memory.DefaultAllocator) + + dec.(encoding.Int32Decoder).SetData(len(tt.toencode), tt.expected) + out := make([]int32, len(tt.toencode)) + dec.(encoding.Int32Decoder).Decode(out) + assert.Equal(t, tt.toencode, out) + }) + } + + t.Run("test progressive decoding", func(t *testing.T) { + values := make([]int32, 1000) + testutils.FillRandomInt32(0, values) + + enc := encoding.NewEncoder(parquet.Types.Int32, parquet.Encodings.DeltaBinaryPacked, false, column, memory.DefaultAllocator) + enc.(encoding.Int32Encoder).Put(values) + buf, _ := enc.FlushValues() + defer buf.Release() + + dec := encoding.NewDecoder(parquet.Types.Int32, parquet.Encodings.DeltaBinaryPacked, column, memory.DefaultAllocator) + dec.(encoding.Int32Decoder).SetData(len(values), buf.Bytes()) + + valueBuf := make([]int32, 100) + for i, j := 0, len(valueBuf); j <= len(values); i, j = i+len(valueBuf), j+len(valueBuf) { + dec.(encoding.Int32Decoder).Decode(valueBuf) + assert.Equalf(t, values[i:j], valueBuf, "indexes %d:%d", i, j) + } + }) +} + +func TestWriteDeltaBitPackedInt64(t *testing.T) { + column := schema.NewColumn(schema.NewInt64Node("int64", parquet.Repetitions.Required, -1), 0, 0) + + tests := []struct { + name string + toencode []int64 + expected []byte + }{ + {"simple 12345", []int64{1, 2, 3, 4, 5}, []byte{128, 1, 4, 5, 2, 2, 0, 0, 0, 0}}, + {"odd vals", []int64{7, 5, 3, 1, 2, 3, 4, 5}, []byte{128, 1, 4, 8, 14, 3, 2, 0, 0, 0, 192, 63, 0, 0, 0, 0, 0, 0}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + enc := encoding.NewEncoder(parquet.Types.Int64, parquet.Encodings.DeltaBinaryPacked, false, column, memory.DefaultAllocator) + + enc.(encoding.Int64Encoder).Put(tt.toencode) + buf, _ := enc.FlushValues() + defer buf.Release() + + assert.Equal(t, tt.expected, buf.Bytes()) + + dec := encoding.NewDecoder(parquet.Types.Int64, parquet.Encodings.DeltaBinaryPacked, column, memory.DefaultAllocator) + + dec.(encoding.Int64Decoder).SetData(len(tt.toencode), tt.expected) + out := make([]int64, 
len(tt.toencode)) + dec.(encoding.Int64Decoder).Decode(out) + assert.Equal(t, tt.toencode, out) + }) + } + + t.Run("test progressive decoding", func(t *testing.T) { + values := make([]int64, 1000) + testutils.FillRandomInt64(0, values) + + enc := encoding.NewEncoder(parquet.Types.Int64, parquet.Encodings.DeltaBinaryPacked, false, column, memory.DefaultAllocator) + enc.(encoding.Int64Encoder).Put(values) + buf, _ := enc.FlushValues() + defer buf.Release() + + dec := encoding.NewDecoder(parquet.Types.Int64, parquet.Encodings.DeltaBinaryPacked, column, memory.DefaultAllocator) + dec.(encoding.Int64Decoder).SetData(len(values), buf.Bytes()) + + valueBuf := make([]int64, 100) + for i, j := 0, len(valueBuf); j <= len(values); i, j = i+len(valueBuf), j+len(valueBuf) { + decoded, _ := dec.(encoding.Int64Decoder).Decode(valueBuf) + assert.Equal(t, len(valueBuf), decoded) + assert.Equalf(t, values[i:j], valueBuf, "indexes %d:%d", i, j) + } + }) +} + +func TestDeltaLengthByteArrayEncoding(t *testing.T) { + column := schema.NewColumn(schema.NewByteArrayNode("bytearray", parquet.Repetitions.Required, -1), 0, 0) + + test := []parquet.ByteArray{[]byte("Hello"), []byte("World"), []byte("Foobar"), []byte("ABCDEF")} + expected := []byte{128, 1, 4, 4, 10, 0, 1, 0, 0, 0, 2, 0, 0, 0, 72, 101, 108, 108, 111, 87, 111, 114, 108, 100, 70, 111, 111, 98, 97, 114, 65, 66, 67, 68, 69, 70} + + enc := encoding.NewEncoder(parquet.Types.ByteArray, parquet.Encodings.DeltaLengthByteArray, false, column, memory.DefaultAllocator) + enc.(encoding.ByteArrayEncoder).Put(test) + buf, _ := enc.FlushValues() + defer buf.Release() + + assert.Equal(t, expected, buf.Bytes()) + + dec := encoding.NewDecoder(parquet.Types.ByteArray, parquet.Encodings.DeltaLengthByteArray, column, nil) + dec.SetData(len(test), expected) + out := make([]parquet.ByteArray, len(test)) + decoded, _ := dec.(encoding.ByteArrayDecoder).Decode(out) + assert.Equal(t, len(test), decoded) + assert.Equal(t, test, out) +} + +func TestDeltaByteArrayEncoding(t *testing.T) { + test := []parquet.ByteArray{[]byte("Hello"), []byte("World"), []byte("Foobar"), []byte("ABCDEF")} + expected := []byte{128, 1, 4, 4, 0, 0, 0, 0, 0, 0, 128, 1, 4, 4, 10, 0, 1, 0, 0, 0, 2, 0, 0, 0, 72, 101, 108, 108, 111, 87, 111, 114, 108, 100, 70, 111, 111, 98, 97, 114, 65, 66, 67, 68, 69, 70} + + enc := encoding.NewEncoder(parquet.Types.ByteArray, parquet.Encodings.DeltaByteArray, false, nil, nil) + enc.(encoding.ByteArrayEncoder).Put(test) + buf, _ := enc.FlushValues() + defer buf.Release() + + assert.Equal(t, expected, buf.Bytes()) + + dec := encoding.NewDecoder(parquet.Types.ByteArray, parquet.Encodings.DeltaByteArray, nil, nil) + dec.SetData(len(test), expected) + out := make([]parquet.ByteArray, len(test)) + decoded, _ := dec.(encoding.ByteArrayDecoder).Decode(out) + assert.Equal(t, len(test), decoded) + assert.Equal(t, test, out) +} diff --git a/go/parquet/internal/encoding/levels.go b/go/parquet/internal/encoding/levels.go new file mode 100644 index 0000000000000..29336bad74917 --- /dev/null +++ b/go/parquet/internal/encoding/levels.go @@ -0,0 +1,288 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package encoding + +import ( + "bytes" + "encoding/binary" + "io" + "math/bits" + + "github.com/JohnCGriffin/overflow" + "github.com/apache/arrow/go/arrow/bitutil" + "github.com/apache/arrow/go/parquet" + format "github.com/apache/arrow/go/parquet/internal/gen-go/parquet" + "github.com/apache/arrow/go/parquet/internal/utils" + "golang.org/x/xerrors" +) + +// LevelEncoder is for handling the encoding of Definition and Repetition levels +// to parquet files. +type LevelEncoder struct { + bitWidth int + rleLen int + encoding format.Encoding + rle *utils.RleEncoder + bit *utils.BitWriter +} + +// LevelEncodingMaxBufferSize estimates the max number of bytes needed to encode data with the +// specified encoding given the max level and number of buffered values provided. +func LevelEncodingMaxBufferSize(encoding parquet.Encoding, maxLvl int16, nbuffered int) int { + bitWidth := bits.Len64(uint64(maxLvl)) + nbytes := 0 + switch encoding { + case parquet.Encodings.RLE: + nbytes = utils.MaxBufferSize(bitWidth, nbuffered) + utils.MinBufferSize(bitWidth) + case parquet.Encodings.BitPacked: + nbytes = int(bitutil.BytesForBits(int64(nbuffered * bitWidth))) + default: + panic("parquet: unknown encoding type for levels") + } + return nbytes +} + +// Reset resets the encoder allowing it to be reused and updating the maxlevel to the new +// specified value. +func (l *LevelEncoder) Reset(maxLvl int16) { + l.bitWidth = bits.Len64(uint64(maxLvl)) + switch l.encoding { + case format.Encoding_RLE: + l.rle.Clear() + l.rle.BitWidth = l.bitWidth + case format.Encoding_BIT_PACKED: + l.bit.Clear() + default: + panic("parquet: unknown encoding type") + } +} + +// Init is called to set up the desired encoding type, max level and underlying writer for a +// level encoder to control where the resulting encoded buffer will end up. +func (l *LevelEncoder) Init(encoding parquet.Encoding, maxLvl int16, w io.WriterAt) { + l.bitWidth = bits.Len64(uint64(maxLvl)) + l.encoding = format.Encoding(encoding) + switch l.encoding { + case format.Encoding_RLE: + l.rle = utils.NewRleEncoder(w, l.bitWidth) + case format.Encoding_BIT_PACKED: + l.bit = utils.NewBitWriter(w) + default: + panic("parquet: unknown encoding type for levels") + } +} + +// EncodeNoFlush encodes the provided levels in the encoder, but doesn't flush +// the buffer and return it yet, appending these encoded values. Returns the number +// of values encoded and any error encountered or nil. If err is not nil, nencoded +// will be the number of values encoded before the error was encountered +func (l *LevelEncoder) EncodeNoFlush(lvls []int16) (nencoded int, err error) { + if l.rle == nil && l.bit == nil { + panic("parquet: level encoders are not initialized") + } + + switch l.encoding { + case format.Encoding_RLE: + for _, level := range lvls { + if err = l.rle.Put(uint64(level)); err != nil { + return + } + nencoded++ + } + default: + for _, level := range lvls { + if err = l.bit.WriteValue(uint64(level), uint(l.bitWidth)); err != nil { + return + } + nencoded++ + } + } + return +} + +// Flush flushes out any encoded data to the underlying writer. 
+func (l *LevelEncoder) Flush() { + if l.rle == nil && l.bit == nil { + panic("parquet: level encoders are not initialized") + } + + switch l.encoding { + case format.Encoding_RLE: + l.rleLen = l.rle.Flush() + default: + l.bit.Flush(false) + } +} + +// Encode encodes the slice of definition or repetition levels based on +// the currently configured encoding type and returns the number of +// values that were encoded. +func (l *LevelEncoder) Encode(lvls []int16) (nencoded int, err error) { + if l.rle == nil && l.bit == nil { + panic("parquet: level encoders are not initialized") + } + + switch l.encoding { + case format.Encoding_RLE: + defer func() { l.rleLen = l.rle.Flush() }() + for _, level := range lvls { + if err = l.rle.Put(uint64(level)); err != nil { + return + } + nencoded++ + } + + default: + defer l.bit.Flush(false) + for _, level := range lvls { + if err = l.bit.WriteValue(uint64(level), uint(l.bitWidth)); err != nil { + return + } + nencoded++ + } + } + return +} + +// Len returns the number of bytes that were written as Run Length encoded +// levels; this is only valid for run length encoding and will panic if using +// the deprecated bit packed encoding. +func (l *LevelEncoder) Len() int { + if l.encoding != format.Encoding_RLE { + panic("parquet: level encoder, only implemented for RLE") + } + return l.rleLen +} + +// LevelDecoder handles the decoding of repetition and definition levels from a +// parquet file supporting bit packed and run length encoded values. +type LevelDecoder struct { + bitWidth int + remaining int // the number of values left to be decoded in the input data + maxLvl int16 + encoding format.Encoding + // only one of the following should ever be set at a time based on the + // encoding format. + rle *utils.RleDecoder + bit *utils.BitReader +} + +// SetData sets the data to be decoded by subsequent calls by specifying the encoding type, +// the maximum level (which is what determines the bit width), the number of values expected, +// and the raw bytes to decode. Returns the number of bytes expected to be decoded.
+func (l *LevelDecoder) SetData(encoding parquet.Encoding, maxLvl int16, nbuffered int, data []byte) (int, error) { + l.maxLvl = maxLvl + l.encoding = format.Encoding(encoding) + l.remaining = nbuffered + l.bitWidth = bits.Len64(uint64(maxLvl)) + + switch encoding { + case parquet.Encodings.RLE: + if len(data) < 4 { + return 0, xerrors.New("parquet: received invalid levels (corrupt data page?)") + } + + nbytes := int32(binary.LittleEndian.Uint32(data[:4])) + if nbytes < 0 || nbytes > int32(len(data)-4) { + return 0, xerrors.New("parquet: received invalid number of bytes (corrupt data page?)") + } + + buf := data[4:] + if l.rle == nil { + l.rle = utils.NewRleDecoder(bytes.NewReader(buf), l.bitWidth) + } else { + l.rle.Reset(bytes.NewReader(buf), l.bitWidth) + } + return int(nbytes) + 4, nil + case parquet.Encodings.BitPacked: + nbits, ok := overflow.Mul(nbuffered, l.bitWidth) + if !ok { + return 0, xerrors.New("parquet: number of buffered values too large (corrupt data page?)") + } + + nbytes := bitutil.BytesForBits(int64(nbits)) + if nbytes < 0 || nbytes > int64(len(data)) { + return 0, xerrors.New("parquet: received invalid number of bytes (corrupt data page?)") + } + if l.bit == nil { + l.bit = utils.NewBitReader(bytes.NewReader(data)) + } else { + l.bit.Reset(bytes.NewReader(data)) + } + return int(nbytes), nil + default: + return 0, xerrors.Errorf("parquet: unknown encoding type for levels '%s'", encoding) + } +} + +// SetDataV2 is the same as SetData but only for DataPageV2 pages and only supports +// run length encoding. +func (l *LevelDecoder) SetDataV2(nbytes int32, maxLvl int16, nbuffered int, data []byte) error { + if nbytes < 0 { + return xerrors.New("parquet: invalid page header (corrupt data page?)") + } + + l.maxLvl = maxLvl + l.encoding = format.Encoding_RLE + l.remaining = nbuffered + l.bitWidth = bits.Len64(uint64(maxLvl)) + + if l.rle == nil { + l.rle = utils.NewRleDecoder(bytes.NewReader(data), l.bitWidth) + } else { + l.rle.Reset(bytes.NewReader(data), l.bitWidth) + } + return nil +} + +// Decode decodes the bytes that were set with SetData into the slice of levels +// returning the total number of levels that were decoded and the number of +// values which had a level equal to the max level, indicating how many physical +// values exist to be read. +func (l *LevelDecoder) Decode(levels []int16) (int, int64) { + var ( + buf [1024]uint64 + totaldecoded int + decoded int + valsToRead int64 + ) + + n := utils.Min(int64(l.remaining), int64(len(levels))) + for n > 0 { + batch := utils.Min(1024, n) + switch l.encoding { + case format.Encoding_RLE: + decoded = l.rle.GetBatch(buf[:batch]) + case format.Encoding_BIT_PACKED: + decoded, _ = l.bit.GetBatch(uint(l.bitWidth), buf[:batch]) + } + l.remaining -= decoded + totaldecoded += decoded + n -= batch + + for idx, val := range buf[:decoded] { + lvl := int16(val) + levels[idx] = lvl + if lvl == l.maxLvl { + valsToRead++ + } + } + levels = levels[decoded:] + } + + return totaldecoded, valsToRead +} diff --git a/go/parquet/internal/encoding/levels_test.go b/go/parquet/internal/encoding/levels_test.go new file mode 100644 index 0000000000000..fe93b4097e1e0 --- /dev/null +++ b/go/parquet/internal/encoding/levels_test.go @@ -0,0 +1,292 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package encoding_test + +import ( + "encoding/binary" + "strconv" + "testing" + + "github.com/apache/arrow/go/arrow" + "github.com/apache/arrow/go/arrow/memory" + "github.com/apache/arrow/go/parquet" + "github.com/apache/arrow/go/parquet/internal/encoding" + "github.com/stretchr/testify/assert" +) + +func generateLevels(minRepeat, maxRepeat int, maxLevel int16) []int16 { + // for each repetition count up to max repeat + ret := make([]int16, 0) + for rep := minRepeat; rep <= maxRepeat; rep++ { + var ( + repCount = 1 << rep + val int16 = 0 + bwidth = 0 + ) + // generate levels for repetition count up to max level + for val <= maxLevel { + for i := 0; i < repCount; i++ { + ret = append(ret, val) + } + val = int16((2 << bwidth) - 1) + bwidth++ + } + } + return ret +} + +func encodeLevels(t *testing.T, enc parquet.Encoding, maxLvl int16, numLevels int, input []int16) []byte { + var ( + encoder encoding.LevelEncoder + lvlCount = 0 + buf = encoding.NewBufferWriter(2*numLevels, memory.DefaultAllocator) + ) + + if enc == parquet.Encodings.RLE { + buf.SetOffset(arrow.Int32SizeBytes) + // leave space to write the rle length value + encoder.Init(enc, maxLvl, buf) + lvlCount, _ = encoder.Encode(input) + buf.SetOffset(0) + arrow.Int32Traits.CastFromBytes(buf.Bytes())[0] = int32(encoder.Len()) + } else { + encoder.Init(enc, maxLvl, buf) + lvlCount, _ = encoder.Encode(input) + } + + assert.Equal(t, numLevels, lvlCount) + return buf.Bytes() +} + +func verifyDecodingLvls(t *testing.T, enc parquet.Encoding, maxLvl int16, input []int16, buf []byte) { + var ( + decoder encoding.LevelDecoder + lvlCount = 0 + numLevels = len(input) + output = make([]int16, numLevels) + decodeCount = 4 + numInnerLevels = numLevels / decodeCount + ) + + // decode levels and test with multiple decode calls + _, err := decoder.SetData(enc, maxLvl, numLevels, buf) + assert.NoError(t, err) + // try multiple decoding on a single setdata call + for ct := 0; ct < decodeCount; ct++ { + offset := ct * numInnerLevels + lvlCount, _ = decoder.Decode(output[:numInnerLevels]) + assert.Equal(t, numInnerLevels, lvlCount) + assert.Equal(t, input[offset:offset+numInnerLevels], output[:numInnerLevels]) + } + + // check the remaining levels + var ( + levelsCompleted = decodeCount * (numLevels / decodeCount) + remaining = numLevels - levelsCompleted + ) + + if remaining > 0 { + lvlCount, _ = decoder.Decode(output[:remaining]) + assert.Equal(t, remaining, lvlCount) + assert.Equal(t, input[levelsCompleted:], output[:remaining]) + } + // test decode zero values + lvlCount, _ = decoder.Decode(output[:1]) + assert.Zero(t, lvlCount) +} + +func verifyDecodingMultipleSetData(t *testing.T, enc parquet.Encoding, max int16, input []int16, buf [][]byte) { + var ( + decoder encoding.LevelDecoder + lvlCount = 0 + setdataCount = len(buf) + numLevels = len(input) / setdataCount + output = make([]int16, numLevels) + ) + + for ct := 0; ct < setdataCount; ct++ { + offset := ct * numLevels + 
assert.Len(t, output, numLevels) + _, err := decoder.SetData(enc, max, numLevels, buf[ct]) + assert.NoError(t, err) + lvlCount, _ = decoder.Decode(output) + assert.Equal(t, numLevels, lvlCount) + assert.Equal(t, input[offset:offset+numLevels], output) + } +} + +func TestLevelsDecodeMultipleBitWidth(t *testing.T) { + t.Parallel() + // Test levels with maximum bit-width from 1 to 8 + // increase the repetition count for each iteration by a factor of 2 + var ( + minRepeat = 0 + maxRepeat = 7 // 128 + maxBitWidth = 8 + input []int16 + buf []byte + encodings = [2]parquet.Encoding{parquet.Encodings.RLE, parquet.Encodings.BitPacked} + ) + + for _, enc := range encodings { + t.Run(enc.String(), func(t *testing.T) { + // bitpacked requires a sequence of at least 8 + if enc == parquet.Encodings.BitPacked { + minRepeat = 3 + } + // for each max bit width + for bitWidth := 1; bitWidth <= maxBitWidth; bitWidth++ { + t.Run(strconv.Itoa(bitWidth), func(t *testing.T) { + max := int16((1 << bitWidth) - 1) + // generate levels + input = generateLevels(minRepeat, maxRepeat, max) + assert.NotPanics(t, func() { + buf = encodeLevels(t, enc, max, len(input), input) + }) + assert.NotPanics(t, func() { + verifyDecodingLvls(t, enc, max, input, buf) + }) + }) + } + }) + } +} + +func TestLevelsDecodeMultipleSetData(t *testing.T) { + t.Parallel() + + var ( + minRepeat = 3 + maxRepeat = 7 + bitWidth = 8 + maxLevel = int16((1 << bitWidth) - 1) + encodings = [2]parquet.Encoding{parquet.Encodings.RLE, parquet.Encodings.BitPacked} + ) + + input := generateLevels(minRepeat, maxRepeat, maxLevel) + + var ( + numLevels = len(input) + setdataFactor = 8 + splitLevelSize = numLevels / setdataFactor + buf = make([][]byte, setdataFactor) + ) + + for _, enc := range encodings { + t.Run(enc.String(), func(t *testing.T) { + for rf := 0; rf < setdataFactor; rf++ { + offset := rf * splitLevelSize + assert.NotPanics(t, func() { + buf[rf] = encodeLevels(t, enc, maxLevel, splitLevelSize, input[offset:offset+splitLevelSize]) + }) + } + assert.NotPanics(t, func() { + verifyDecodingMultipleSetData(t, enc, maxLevel, input, buf) + }) + }) + } +} + +func TestMinimumBufferSize(t *testing.T) { + t.Parallel() + + const numToEncode = 1024 + levels := make([]int16, numToEncode) + + for idx := range levels { + if idx%9 == 0 { + levels[idx] = 0 + } else { + levels[idx] = 1 + } + } + + output := encoding.NewBufferWriter(0, memory.DefaultAllocator) + + var encoder encoding.LevelEncoder + encoder.Init(parquet.Encodings.RLE, 1, output) + count, _ := encoder.Encode(levels) + assert.Equal(t, numToEncode, count) +} + +func TestMinimumBufferSize2(t *testing.T) { + t.Parallel() + + // test the worst case for bit_width=2 consisting of + // LiteralRun(size=8) + // RepeatedRun(size=8) + // LiteralRun(size=8) + // ... 
+ const numToEncode = 1024 + levels := make([]int16, numToEncode) + + for idx := range levels { + // This forces a literal run of 00000001 + // followed by eight 1s + if (idx % 16) < 7 { + levels[idx] = 0 + } else { + levels[idx] = 1 + } + } + + for bitWidth := int16(1); bitWidth <= 8; bitWidth++ { + output := encoding.NewBufferWriter(0, memory.DefaultAllocator) + + var encoder encoding.LevelEncoder + encoder.Init(parquet.Encodings.RLE, bitWidth, output) + count, _ := encoder.Encode(levels) + assert.Equal(t, numToEncode, count) + } +} + +func TestEncodeDecodeLevels(t *testing.T) { + t.Parallel() + const numToEncode = 2048 + levels := make([]int16, numToEncode) + numones := 0 + for idx := range levels { + if (idx % 16) < 7 { + levels[idx] = 0 + } else { + levels[idx] = 1 + numones++ + } + } + + output := encoding.NewBufferWriter(0, memory.DefaultAllocator) + + var encoder encoding.LevelEncoder + encoder.Init(parquet.Encodings.RLE, 1, output) + count, _ := encoder.Encode(levels) + assert.Equal(t, numToEncode, count) + encoder.Flush() + + buf := output.Bytes() + var prefix [4]byte + binary.LittleEndian.PutUint32(prefix[:], uint32(len(buf))) + + var decoder encoding.LevelDecoder + _, err := decoder.SetData(parquet.Encodings.RLE, 1, numToEncode, append(prefix[:], buf...)) + assert.NoError(t, err) + + var levelOut [numToEncode]int16 + total, vals := decoder.Decode(levelOut[:]) + assert.EqualValues(t, numToEncode, total) + assert.EqualValues(t, numones, vals) + assert.Equal(t, levels, levelOut[:]) +} diff --git a/go/parquet/internal/encoding/memo_table.go b/go/parquet/internal/encoding/memo_table.go new file mode 100644 index 0000000000000..9a04e6e0d025c --- /dev/null +++ b/go/parquet/internal/encoding/memo_table.go @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package encoding + +import ( + "math" + "unsafe" + + "github.com/apache/arrow/go/arrow" + "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/memory" + "github.com/apache/arrow/go/parquet" + "github.com/apache/arrow/go/parquet/internal/hashing" +) + +//go:generate go run ../../../arrow/_tools/tmpl/main.go -i -data=physical_types.tmpldata memo_table_types.gen.go.tmpl + +// MemoTable interface that can be used to swap out implementations of the hash table +// used for handling dictionary encoding. Dictionary encoding is built against this interface +// to make it easy for code generation and changing implementations. 
+//
+// Values should remember the order they are inserted to generate a valid dictionary index
+type MemoTable interface {
+	// Reset drops everything in the table allowing it to be reused
+	Reset()
+	// Size returns the current number of unique values stored in the table
+	// including whether or not a null value has been passed in using GetOrInsertNull
+	Size() int
+	// CopyValues populates out with the values currently in the table, out must
+	// be a slice of the appropriate type for the table type.
+	CopyValues(out interface{})
+	// CopyValuesSubset is like CopyValues but only copies a subset of values starting
+	// at the indicated index.
+	CopyValuesSubset(start int, out interface{})
+	// Get returns the index in the table of the specified value, and a boolean indicating
+	// whether or not the value was found in the table. Will panic if val is not the appropriate
+	// type for the underlying table.
+	Get(val interface{}) (int, bool)
+	// GetOrInsert is the same as Get, except if the value is not currently in the table it will
+	// be inserted into the table.
+	GetOrInsert(val interface{}) (idx int, existed bool, err error)
+	// GetNull returns the index of the null value and whether or not it was found in the table
+	GetNull() (int, bool)
+	// GetOrInsertNull returns the index of the null value; if it didn't already exist in the table,
+	// it is inserted.
+	GetOrInsertNull() (idx int, existed bool)
+}
+
+// BinaryMemoTable is an extension of the MemoTable interface adding extra methods
+// for handling byte arrays/strings/fixed length byte arrays.
+type BinaryMemoTable interface {
+	MemoTable
+	// ValuesSize returns the total number of bytes needed to copy all of the values
+	// from this table.
+	ValuesSize() int
+	// CopyOffsets populates out with the start and end offsets of each value in the
+	// table data. Out should be sized to Size()+1 to accommodate all of the offsets.
+	CopyOffsets(out []int8)
+	// CopyOffsetsSubset is like CopyOffsets but only gets a subset of the offsets
+	// starting at the specified index.
+	CopyOffsetsSubset(start int, out []int8)
+	// CopyFixedWidthValues exists to cope with the fact that the table doesn't track
+	// the fixed width when inserting the null value into the data buffer, populating
+	// a zero length byte slice for the null value (if found).
+	CopyFixedWidthValues(start int, width int, out []byte)
+	// VisitValues calls visitFn on each value in the table starting with the index specified
+	VisitValues(start int, visitFn func([]byte))
+	// Retain increases the reference count of the separately stored binary data that is
+	// kept alongside the table which contains all of the values in the table. This is
+	// safe to call simultaneously across multiple goroutines.
+	Retain()
+	// Release decreases the reference count by 1 of the separately stored binary data
+	// kept alongside the table containing the values. When the reference count goes to
+	// 0, the memory is freed. This is safe to call across multiple goroutines simultaneously.
+ Release() +} + +// NewInt32Dictionary returns a memotable interface for use with Int32 values only +func NewInt32Dictionary() MemoTable { + return hashing.NewInt32MemoTable(0) +} + +// NewInt64Dictionary returns a memotable interface for use with Int64 values only +func NewInt64Dictionary() MemoTable { + return hashing.NewInt64MemoTable(0) +} + +// NewFloat32Dictionary returns a memotable interface for use with Float32 values only +func NewFloat32Dictionary() MemoTable { + return hashing.NewFloat32MemoTable(0) +} + +// NewFloat64Dictionary returns a memotable interface for use with Float64 values only +func NewFloat64Dictionary() MemoTable { + return hashing.NewFloat64MemoTable(0) +} + +// NewBinaryDictionary returns a memotable interface for use with strings, byte slices, +// parquet.ByteArray and parquet.FixedLengthByteArray only. +func NewBinaryDictionary(mem memory.Allocator) BinaryMemoTable { + return hashing.NewBinaryMemoTable(mem, 0, -1) +} + +const keyNotFound = hashing.KeyNotFound + +// standard map based implementation of a binary memotable which is only kept around +// currently to be used as a benchmark against the memotables in the internal/hashing +// module as a baseline comparison. + +func NewBinaryMemoTable(mem memory.Allocator) BinaryMemoTable { + return &binaryMemoTableImpl{ + table: make(map[string]int), + nullIndex: keyNotFound, + builder: array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary), + } +} + +type binaryMemoTableImpl struct { + table map[string]int + builder *array.BinaryBuilder + nullIndex int +} + +func (m *binaryMemoTableImpl) Reset() { + m.table = make(map[string]int) + m.nullIndex = keyNotFound + m.builder.NewArray().Release() +} + +func (m *binaryMemoTableImpl) CopyValues(out interface{}) { + m.CopyValuesSubset(0, out) +} + +func (m *binaryMemoTableImpl) GetNull() (int, bool) { + return m.nullIndex, m.nullIndex != keyNotFound +} + +func (m *binaryMemoTableImpl) ValuesSize() int { + return m.builder.DataLen() +} + +func (m *binaryMemoTableImpl) Size() int { + sz := len(m.table) + if _, ok := m.GetNull(); ok { + sz++ + } + return sz +} + +func (m *binaryMemoTableImpl) valAsString(val interface{}) string { + switch v := val.(type) { + case string: + return v + case []byte: + return *(*string)(unsafe.Pointer(&v)) + case parquet.ByteArray: + return *(*string)(unsafe.Pointer(&v)) + case parquet.FixedLenByteArray: + return *(*string)(unsafe.Pointer(&v)) + default: + panic("invalid type for value in binarymemotable") + } +} + +func (m *binaryMemoTableImpl) Get(val interface{}) (int, bool) { + key := m.valAsString(val) + if p, ok := m.table[key]; ok { + return p, true + } + return keyNotFound, false +} + +func (m *binaryMemoTableImpl) GetOrInsert(val interface{}) (idx int, found bool, err error) { + key := m.valAsString(val) + idx, found = m.table[key] + if !found { + idx = m.Size() + m.builder.AppendString(key) + m.table[key] = idx + } + return +} + +func (m *binaryMemoTableImpl) GetOrInsertNull() (idx int, found bool) { + idx, found = m.GetNull() + if !found { + idx = m.Size() + m.nullIndex = idx + m.builder.AppendNull() + } + return +} + +func (m *binaryMemoTableImpl) findOffset(idx int) uintptr { + val := m.builder.Value(idx) + for len(val) == 0 { + idx++ + if idx >= m.builder.Len() { + break + } + val = m.builder.Value(idx) + } + if len(val) != 0 { + return uintptr(unsafe.Pointer(&val[0])) + } + return uintptr(m.builder.DataLen()) + m.findOffset(0) +} + +func (m *binaryMemoTableImpl) CopyValuesSubset(start int, out interface{}) { + var ( + first = 
m.findOffset(0)
+		offset = m.findOffset(int(start))
+		length = m.builder.DataLen() - int(offset-first)
+	)
+
+	outval := out.([]byte)
+	copy(outval, m.builder.Value(start)[0:length])
+}
+
+// CopyFixedWidthValues is a no-op for this map-based baseline implementation.
+func (m *binaryMemoTableImpl) CopyFixedWidthValues(start, width int, out []byte) {
+
+}
+
+func (m *binaryMemoTableImpl) CopyOffsetsSubset(start int, out []int8) {
+	if m.builder.Len() <= start {
+		return
+	}
+
+	first := m.findOffset(0)
+	delta := m.findOffset(start)
+	for i := start; i < m.Size(); i++ {
+		offset := int8(m.findOffset(i) - delta)
+		out[i-start] = offset
+	}
+
+	out[m.Size()-start] = int8(m.builder.DataLen() - int(delta-first))
+}
+
+func (m *binaryMemoTableImpl) CopyOffsets(out []int8) {
+	m.CopyOffsetsSubset(0, out)
+}
+
+func (m *binaryMemoTableImpl) VisitValues(start int, visitFn func([]byte)) {
+	for i := int(start); i < m.Size(); i++ {
+		visitFn(m.builder.Value(i))
+	}
+}
+
+func (m *binaryMemoTableImpl) Release() {
+	m.builder.Release()
+}
+
+func (m *binaryMemoTableImpl) Retain() {
+	m.builder.Retain()
+}
+
+// standard map based implementation of a float64 memotable which is only kept around
+// currently to be used as a benchmark against the memotables in the internal/hashing
+// module as a baseline comparison.
+
+func NewFloat64MemoTable(memory.Allocator) MemoTable {
+	return &float64MemoTableImpl{
+		table: make(map[float64]struct {
+			value     float64
+			memoIndex int
+		}),
+		nullIndex: keyNotFound,
+		nanIndex:  keyNotFound,
+	}
+}
+
+type float64MemoTableImpl struct {
+	table map[float64]struct {
+		value     float64
+		memoIndex int
+	}
+	nullIndex int
+	nanIndex  int
+}
+
+func (m *float64MemoTableImpl) Reset() {
+	m.table = make(map[float64]struct {
+		value     float64
+		memoIndex int
+	})
+	m.nullIndex = keyNotFound
+	m.nanIndex = keyNotFound
+}
+
+func (m *float64MemoTableImpl) GetNull() (int, bool) {
+	return m.nullIndex, m.nullIndex != keyNotFound
+}
+
+func (m *float64MemoTableImpl) Size() int {
+	sz := len(m.table)
+	if _, ok := m.GetNull(); ok {
+		sz++
+	}
+	if m.nanIndex != keyNotFound {
+		sz++
+	}
+	return sz
+}
+
+func (m *float64MemoTableImpl) GetOrInsertNull() (idx int, found bool) {
+	idx, found = m.GetNull()
+	if !found {
+		idx = m.Size()
+		m.nullIndex = idx
+	}
+	return
+}
+
+func (m *float64MemoTableImpl) Get(val interface{}) (int, bool) {
+	v := val.(float64)
+	if p, ok := m.table[v]; ok {
+		return p.memoIndex, true
+	}
+	if math.IsNaN(v) && m.nanIndex != keyNotFound {
+		return m.nanIndex, true
+	}
+	return keyNotFound, false
+}
+
+func (m *float64MemoTableImpl) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	v := val.(float64)
+	if math.IsNaN(v) {
+		if m.nanIndex == keyNotFound {
+			idx = m.Size()
+			m.nanIndex = idx
+		} else {
+			idx = m.nanIndex
+			found = true
+		}
+		return
+	}
+
+	p, ok := m.table[v]
+	if ok {
+		idx = p.memoIndex
+		found = true
+	} else {
+		idx = m.Size()
+		p.value = v
+		p.memoIndex = idx
+		m.table[v] = p
+	}
+	return
+}
+
+func (m *float64MemoTableImpl) CopyValues(out interface{}) {
+	m.CopyValuesSubset(0, out)
+}
+
+func (m *float64MemoTableImpl) CopyValuesSubset(start int, out interface{}) {
+	outval := out.([]float64)
+	for _, v := range m.table {
+		idx := v.memoIndex - start
+		if idx >= 0 {
+			outval[idx] = v.value
+		}
+	}
+	if m.nanIndex != keyNotFound {
+		if idx := m.nanIndex - start; idx >= 0 {
+			outval[idx] = math.NaN()
+		}
+	}
+}
diff --git a/go/parquet/internal/encoding/memo_table_test.go b/go/parquet/internal/encoding/memo_table_test.go
new file mode 100644
index 0000000000000..82432763f7404
--- /dev/null
+++ b/go/parquet/internal/encoding/memo_table_test.go
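The tests added below exercise this contract end to end. As a minimal sketch of how a dictionary index accumulates in one of these tables (illustrative only: encoding is an internal package, so this compiles only from within the arrow module, the map-based NewFloat64MemoTable above is the benchmark baseline rather than the production hashing implementation, and fmt/math imports are assumed):

	table := encoding.NewFloat64MemoTable(nil)

	// each distinct value receives the next index, in first-insertion order
	for _, v := range []float64{1.5, 1.5, math.NaN(), 2.5, math.NaN()} {
		idx, _, _ := table.GetOrInsert(v)
		fmt.Println(idx) // prints 0 0 1 2 1; all NaN bit patterns share one slot
	}

	// nulls occupy their own slot
	nullIdx, _ := table.GetOrInsertNull() // 3, since three unique values precede it

	// the dictionary page is then just the values in insertion order
	dict := make([]float64, table.Size())
	table.CopyValues(dict) // [1.5 NaN 2.5 0]; the null slot is left zeroed
	_, _ = nullIdx, dict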
@@ -0,0 +1,291 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package encoding_test + +import ( + "math" + "testing" + + "github.com/apache/arrow/go/arrow/memory" + "github.com/apache/arrow/go/parquet/internal/encoding" + "github.com/apache/arrow/go/parquet/internal/hashing" + "github.com/stretchr/testify/suite" +) + +type MemoTableTestSuite struct { + suite.Suite +} + +func TestMemoTable(t *testing.T) { + suite.Run(t, new(MemoTableTestSuite)) +} + +func (m *MemoTableTestSuite) assertGetNotFound(table encoding.MemoTable, v interface{}) { + _, ok := table.Get(v) + m.False(ok) +} + +func (m *MemoTableTestSuite) assertGet(table encoding.MemoTable, v interface{}, expected int) { + idx, ok := table.Get(v) + m.Equal(expected, idx) + m.True(ok) +} + +func (m *MemoTableTestSuite) assertGetOrInsert(table encoding.MemoTable, v interface{}, expected int) { + idx, _, err := table.GetOrInsert(v) + m.NoError(err) + m.Equal(expected, idx) +} + +func (m *MemoTableTestSuite) assertGetNullNotFound(table encoding.MemoTable) { + _, ok := table.GetNull() + m.False(ok) +} + +func (m *MemoTableTestSuite) assertGetNull(table encoding.MemoTable, expected int) { + idx, ok := table.GetNull() + m.Equal(expected, idx) + m.True(ok) +} + +func (m *MemoTableTestSuite) assertGetOrInsertNull(table encoding.MemoTable, expected int) { + idx, _ := table.GetOrInsertNull() + m.Equal(expected, idx) +} + +func (m *MemoTableTestSuite) TestInt64() { + const ( + A int64 = 1234 + B int64 = 0 + C int64 = -98765321 + D int64 = 12345678901234 + E int64 = -1 + F int64 = 1 + G int64 = 9223372036854775807 + H int64 = -9223372036854775807 - 1 + ) + + // table := encoding.NewInt64MemoTable(nil) + table := hashing.NewInt64MemoTable(0) + m.Zero(table.Size()) + m.assertGetNotFound(table, A) + m.assertGetNullNotFound(table) + m.assertGetOrInsert(table, A, 0) + m.assertGetNotFound(table, B) + m.assertGetOrInsert(table, B, 1) + m.assertGetOrInsert(table, C, 2) + m.assertGetOrInsert(table, D, 3) + m.assertGetOrInsert(table, E, 4) + m.assertGetOrInsertNull(table, 5) + + m.assertGet(table, A, 0) + m.assertGetOrInsert(table, A, 0) + m.assertGet(table, E, 4) + m.assertGetOrInsert(table, E, 4) + + m.assertGetOrInsert(table, F, 6) + m.assertGetOrInsert(table, G, 7) + m.assertGetOrInsert(table, H, 8) + + m.assertGetOrInsert(table, G, 7) + m.assertGetOrInsert(table, F, 6) + m.assertGetOrInsertNull(table, 5) + m.assertGetOrInsert(table, E, 4) + m.assertGetOrInsert(table, D, 3) + m.assertGetOrInsert(table, C, 2) + m.assertGetOrInsert(table, B, 1) + m.assertGetOrInsert(table, A, 0) + + const sz int = 9 + m.Equal(sz, table.Size()) + m.Panics(func() { + values := make([]int32, sz) + table.CopyValues(values) + }, "should panic because wrong type") + m.Panics(func() { + values := make([]int64, sz-3) + 
table.CopyValues(values) + }, "should panic because out of bounds") + + { + values := make([]int64, sz) + table.CopyValues(values) + m.Equal([]int64{A, B, C, D, E, 0, F, G, H}, values) + } + { + const offset = 3 + values := make([]int64, sz-offset) + table.CopyValuesSubset(offset, values) + m.Equal([]int64{D, E, 0, F, G, H}, values) + } +} + +func (m *MemoTableTestSuite) TestFloat64() { + const ( + A float64 = 0.0 + B float64 = 1.5 + C float64 = -0.1 + ) + var ( + D = math.Inf(1) + E = -D + F = math.NaN() // uses Quiet NaN i.e. 0x7FF8000000000001 + G = math.Float64frombits(uint64(0x7FF0000000000001)) // test Signalling NaN + H = math.Float64frombits(uint64(0xFFF7FFFFFFFFFFFF)) // other NaN bit pattern + ) + + // table := encoding.NewFloat64MemoTable(nil) + table := hashing.NewFloat64MemoTable(0) + m.Zero(table.Size()) + m.assertGetNotFound(table, A) + m.assertGetNullNotFound(table) + m.assertGetOrInsert(table, A, 0) + m.assertGetNotFound(table, B) + m.assertGetOrInsert(table, B, 1) + m.assertGetOrInsert(table, C, 2) + m.assertGetOrInsert(table, D, 3) + m.assertGetOrInsert(table, E, 4) + m.assertGetOrInsert(table, F, 5) + m.assertGetOrInsert(table, G, 5) + m.assertGetOrInsert(table, H, 5) + + m.assertGet(table, A, 0) + m.assertGetOrInsert(table, A, 0) + m.assertGetOrInsert(table, B, 1) + m.assertGetOrInsert(table, C, 2) + m.assertGetOrInsert(table, D, 3) + m.assertGet(table, E, 4) + m.assertGetOrInsert(table, E, 4) + m.assertGet(table, F, 5) + m.assertGetOrInsert(table, F, 5) + m.assertGet(table, G, 5) + m.assertGetOrInsert(table, G, 5) + m.assertGet(table, H, 5) + m.assertGetOrInsert(table, H, 5) + + m.Equal(6, table.Size()) + expected := []float64{A, B, C, D, E, F} + m.Panics(func() { + values := make([]int32, 6) + table.CopyValues(values) + }, "should panic because wrong type") + m.Panics(func() { + values := make([]float64, 3) + table.CopyValues(values) + }, "should panic because out of bounds") + + values := make([]float64, len(expected)) + table.CopyValues(values) + for idx, ex := range expected { + if math.IsNaN(ex) { + m.True(math.IsNaN(values[idx])) + } else { + m.Equal(ex, values[idx]) + } + } +} + +func (m *MemoTableTestSuite) TestBinaryBasics() { + const ( + A = "" + B = "a" + C = "foo" + D = "bar" + E = "\000" + F = "\000trailing" + ) + + table := hashing.NewBinaryMemoTable(memory.DefaultAllocator, 0, -1) + defer table.Release() + + m.Zero(table.Size()) + m.assertGetNotFound(table, A) + m.assertGetNullNotFound(table) + m.assertGetOrInsert(table, A, 0) + m.assertGetNotFound(table, B) + m.assertGetOrInsert(table, B, 1) + m.assertGetOrInsert(table, C, 2) + m.assertGetOrInsert(table, D, 3) + m.assertGetOrInsert(table, E, 4) + m.assertGetOrInsert(table, F, 5) + m.assertGetOrInsertNull(table, 6) + + m.assertGet(table, A, 0) + m.assertGetOrInsert(table, A, 0) + m.assertGet(table, B, 1) + m.assertGetOrInsert(table, B, 1) + m.assertGetOrInsert(table, C, 2) + m.assertGetOrInsert(table, D, 3) + m.assertGetOrInsert(table, E, 4) + m.assertGet(table, F, 5) + m.assertGetOrInsert(table, F, 5) + m.assertGetNull(table, 6) + m.assertGetOrInsertNull(table, 6) + + m.Equal(7, table.Size()) + m.Equal(17, table.ValuesSize()) + + size := table.Size() + { + offsets := make([]int8, size+1) + table.CopyOffsets(offsets) + m.Equal([]int8{0, 0, 1, 4, 7, 8, 17, 17}, offsets) + + expectedValues := "afoobar" + expectedValues += "\000" + expectedValues += "\000" + expectedValues += "trailing" + values := make([]byte, 17) + table.CopyValues(values) + m.Equal(expectedValues, string(values)) + } + + { + startOffset 
:= 4 + offsets := make([]int8, size+1-int(startOffset)) + table.CopyOffsetsSubset(startOffset, offsets) + m.Equal([]int8{0, 1, 10, 10}, offsets) + + expectedValues := "" + expectedValues += "\000" + expectedValues += "\000" + expectedValues += "trailing" + + values := make([]byte, 10) + table.CopyValuesSubset(startOffset, values) + m.Equal(expectedValues, string(values)) + } + + { + startOffset := 1 + values := make([]string, 0) + table.VisitValues(startOffset, func(b []byte) { + values = append(values, string(b)) + }) + m.Equal([]string{B, C, D, E, F, ""}, values) + } +} + +func (m *MemoTableTestSuite) TestBinaryEmpty() { + table := encoding.NewBinaryMemoTable(memory.DefaultAllocator) + defer table.Release() + + m.Zero(table.Size()) + offsets := make([]int8, 1) + table.CopyOffsetsSubset(0, offsets) + m.Equal(int8(0), offsets[0]) +} diff --git a/go/parquet/internal/encoding/memo_table_types.gen.go b/go/parquet/internal/encoding/memo_table_types.gen.go new file mode 100644 index 0000000000000..5c4812cbbebcc --- /dev/null +++ b/go/parquet/internal/encoding/memo_table_types.gen.go @@ -0,0 +1,366 @@ +// Code generated by memo_table_types.gen.go.tmpl. DO NOT EDIT. + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package encoding + +import ( + "github.com/apache/arrow/go/arrow/memory" + "github.com/apache/arrow/go/parquet" +) + +// standard map based implementation of memo tables which can be more efficient +// in some cases based on the uniqueness / amount / size of the data. +// these are left here for now for use in the benchmarks to compare against the +// custom hash table implementation in the internal/hashing package as a base +// benchmark comparison. 
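Since these map-based tables exist purely as a baseline, a benchmark pitting them against the internal/hashing implementations might look like the following sketch (hypothetical benchmark, not part of this change; it assumes a testing import and the internal/hashing package imported as hashing):

	func BenchmarkGetOrInsertInt32(b *testing.B) {
		values := make([]int32, 1024)
		for i := range values {
			values[i] = int32(i % 128) // repeats exercise the lookup path, not just insertion
		}

		b.Run("map", func(b *testing.B) {
			for n := 0; n < b.N; n++ {
				table := NewInt32MemoTable(nil)
				for _, v := range values {
					table.GetOrInsert(v)
				}
			}
		})

		b.Run("hashing", func(b *testing.B) {
			for n := 0; n < b.N; n++ {
				table := hashing.NewInt32MemoTable(0)
				for _, v := range values {
					table.GetOrInsert(v)
				}
			}
		})
	}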
+
+func NewInt32MemoTable(memory.Allocator) MemoTable {
+	return &int32MemoTableImpl{
+		table: make(map[int32]struct {
+			value     int32
+			memoIndex int
+		}),
+		nullIndex: keyNotFound,
+	}
+}
+
+type int32MemoTableImpl struct {
+	table map[int32]struct {
+		value     int32
+		memoIndex int
+	}
+	nullIndex int
+}
+
+func (m *int32MemoTableImpl) Reset() {
+	m.table = make(map[int32]struct {
+		value     int32
+		memoIndex int
+	})
+	m.nullIndex = keyNotFound
+}
+
+func (m *int32MemoTableImpl) GetNull() (int, bool) {
+	return m.nullIndex, m.nullIndex != keyNotFound
+}
+
+func (m *int32MemoTableImpl) Size() int {
+	sz := len(m.table)
+	if _, ok := m.GetNull(); ok {
+		sz++
+	}
+	return sz
+}
+
+func (m *int32MemoTableImpl) GetOrInsertNull() (idx int, found bool) {
+	idx, found = m.GetNull()
+	if !found {
+		idx = m.Size()
+		m.nullIndex = idx
+	}
+	return
+}
+
+func (m *int32MemoTableImpl) Get(val interface{}) (int, bool) {
+	v := val.(int32)
+	if p, ok := m.table[v]; ok {
+		return p.memoIndex, true
+	}
+	return keyNotFound, false
+}
+
+func (m *int32MemoTableImpl) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	v := val.(int32)
+	p, ok := m.table[v]
+	if ok {
+		idx = p.memoIndex
+		found = true
+	} else {
+		idx = m.Size()
+		p.value = v
+		p.memoIndex = idx
+		m.table[v] = p
+	}
+	return
+}
+
+func (m *int32MemoTableImpl) CopyValues(out interface{}) {
+	m.CopyValuesSubset(0, out)
+}
+
+func (m *int32MemoTableImpl) CopyValuesSubset(start int, out interface{}) {
+	outval := out.([]int32)
+	for _, v := range m.table {
+		idx := v.memoIndex - start
+		if idx >= 0 {
+			outval[idx] = v.value
+		}
+	}
+}
+
+func NewInt64MemoTable(memory.Allocator) MemoTable {
+	return &int64MemoTableImpl{
+		table: make(map[int64]struct {
+			value     int64
+			memoIndex int
+		}),
+		nullIndex: keyNotFound,
+	}
+}
+
+type int64MemoTableImpl struct {
+	table map[int64]struct {
+		value     int64
+		memoIndex int
+	}
+	nullIndex int
+}
+
+func (m *int64MemoTableImpl) Reset() {
+	m.table = make(map[int64]struct {
+		value     int64
+		memoIndex int
+	})
+	m.nullIndex = keyNotFound
+}
+
+func (m *int64MemoTableImpl) GetNull() (int, bool) {
+	return m.nullIndex, m.nullIndex != keyNotFound
+}
+
+func (m *int64MemoTableImpl) Size() int {
+	sz := len(m.table)
+	if _, ok := m.GetNull(); ok {
+		sz++
+	}
+	return sz
+}
+
+func (m *int64MemoTableImpl) GetOrInsertNull() (idx int, found bool) {
+	idx, found = m.GetNull()
+	if !found {
+		idx = m.Size()
+		m.nullIndex = idx
+	}
+	return
+}
+
+func (m *int64MemoTableImpl) Get(val interface{}) (int, bool) {
+	v := val.(int64)
+	if p, ok := m.table[v]; ok {
+		return p.memoIndex, true
+	}
+	return keyNotFound, false
+}
+
+func (m *int64MemoTableImpl) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	v := val.(int64)
+	p, ok := m.table[v]
+	if ok {
+		idx = p.memoIndex
+		found = true
+	} else {
+		idx = m.Size()
+		p.value = v
+		p.memoIndex = idx
+		m.table[v] = p
+	}
+	return
+}
+
+func (m *int64MemoTableImpl) CopyValues(out interface{}) {
+	m.CopyValuesSubset(0, out)
+}
+
+func (m *int64MemoTableImpl) CopyValuesSubset(start int, out interface{}) {
+	outval := out.([]int64)
+	for _, v := range m.table {
+		idx := v.memoIndex - start
+		if idx >= 0 {
+			outval[idx] = v.value
+		}
+	}
+}
+
+func NewInt96MemoTable(memory.Allocator) MemoTable {
+	return &int96MemoTableImpl{
+		table: make(map[parquet.Int96]struct {
+			value     parquet.Int96
+			memoIndex int
+		}),
+		nullIndex: keyNotFound,
+	}
+}
+
+type int96MemoTableImpl struct {
+	table map[parquet.Int96]struct {
+		value     parquet.Int96
+		memoIndex int
+	}
+	nullIndex int
+}
+
+func (m *int96MemoTableImpl) Reset() {
+	m.table = make(map[parquet.Int96]struct {
+		value     parquet.Int96
+		memoIndex int
+	})
+	m.nullIndex = keyNotFound
+}
+
+func (m *int96MemoTableImpl) GetNull() (int, bool) {
+	return m.nullIndex, m.nullIndex != keyNotFound
+}
+
+func (m *int96MemoTableImpl) Size() int {
+	sz := len(m.table)
+	if _, ok := m.GetNull(); ok {
+		sz++
+	}
+	return sz
+}
+
+func (m *int96MemoTableImpl) GetOrInsertNull() (idx int, found bool) {
+	idx, found = m.GetNull()
+	if !found {
+		idx = m.Size()
+		m.nullIndex = idx
+	}
+	return
+}
+
+func (m *int96MemoTableImpl) Get(val interface{}) (int, bool) {
+	v := val.(parquet.Int96)
+	if p, ok := m.table[v]; ok {
+		return p.memoIndex, true
+	}
+	return keyNotFound, false
+}
+
+func (m *int96MemoTableImpl) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	v := val.(parquet.Int96)
+	p, ok := m.table[v]
+	if ok {
+		idx = p.memoIndex
+		found = true
+	} else {
+		idx = m.Size()
+		p.value = v
+		p.memoIndex = idx
+		m.table[v] = p
+	}
+	return
+}
+
+func (m *int96MemoTableImpl) CopyValues(out interface{}) {
+	m.CopyValuesSubset(0, out)
+}
+
+func (m *int96MemoTableImpl) CopyValuesSubset(start int, out interface{}) {
+	outval := out.([]parquet.Int96)
+	for _, v := range m.table {
+		idx := v.memoIndex - start
+		if idx >= 0 {
+			outval[idx] = v.value
+		}
+	}
+}
+
+func NewFloat32MemoTable(memory.Allocator) MemoTable {
+	return &float32MemoTableImpl{
+		table: make(map[float32]struct {
+			value     float32
+			memoIndex int
+		}),
+		nullIndex: keyNotFound,
+	}
+}
+
+type float32MemoTableImpl struct {
+	table map[float32]struct {
+		value     float32
+		memoIndex int
+	}
+	nullIndex int
+}
+
+func (m *float32MemoTableImpl) Reset() {
+	m.table = make(map[float32]struct {
+		value     float32
+		memoIndex int
+	})
+	m.nullIndex = keyNotFound
+}
+
+func (m *float32MemoTableImpl) GetNull() (int, bool) {
+	return m.nullIndex, m.nullIndex != keyNotFound
+}
+
+func (m *float32MemoTableImpl) Size() int {
+	sz := len(m.table)
+	if _, ok := m.GetNull(); ok {
+		sz++
+	}
+	return sz
+}
+
+func (m *float32MemoTableImpl) GetOrInsertNull() (idx int, found bool) {
+	idx, found = m.GetNull()
+	if !found {
+		idx = m.Size()
+		m.nullIndex = idx
+	}
+	return
+}
+
+func (m *float32MemoTableImpl) Get(val interface{}) (int, bool) {
+	v := val.(float32)
+	if p, ok := m.table[v]; ok {
+		return p.memoIndex, true
+	}
+	return keyNotFound, false
+}
+
+func (m *float32MemoTableImpl) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	v := val.(float32)
+	p, ok := m.table[v]
+	if ok {
+		idx = p.memoIndex
+		found = true
+	} else {
+		idx = m.Size()
+		p.value = v
+		p.memoIndex = idx
+		m.table[v] = p
+	}
+	return
+}
+
+func (m *float32MemoTableImpl) CopyValues(out interface{}) {
+	m.CopyValuesSubset(0, out)
+}
+
+func (m *float32MemoTableImpl) CopyValuesSubset(start int, out interface{}) {
+	outval := out.([]float32)
+	for _, v := range m.table {
+		idx := v.memoIndex - start
+		if idx >= 0 {
+			outval[idx] = v.value
+		}
+	}
+}
diff --git a/go/parquet/internal/encoding/memo_table_types.gen.go.tmpl b/go/parquet/internal/encoding/memo_table_types.gen.go.tmpl
new file mode 100644
index 0000000000000..0a0a7af29205c
--- /dev/null
+++ b/go/parquet/internal/encoding/memo_table_types.gen.go.tmpl
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package encoding
+
+import (
+	"github.com/apache/arrow/go/parquet"
+)
+
+// standard map based implementation of memo tables which can be more efficient
+// in some cases based on the uniqueness / amount / size of the data.
+// these are left here for now for use in the benchmarks to compare against the
+// custom hash table implementation in the internal/hashing package as a base
+// benchmark comparison.
+
+{{range .In}}
+{{if and (ne .Name "ByteArray") (ne .Name "FixedLenByteArray") (ne .Name "Float64") (ne .Name "Boolean")}}
+func New{{.Name}}MemoTable(memory.Allocator) MemoTable {
+	return &{{.lower}}MemoTableImpl{
+		table: make(map[{{.name}}]struct{
+			value {{.name}}
+			memoIndex int
+		}),
+		nullIndex: keyNotFound,
+	}
+}
+
+type {{.lower}}MemoTableImpl struct {
+	table map[{{.name}}]struct{
+		value {{.name}}
+		memoIndex int
+	}
+	nullIndex int
+}
+
+func (m *{{.lower}}MemoTableImpl) Reset() {
+	m.table = make(map[{{.name}}]struct{
+		value {{.name}}
+		memoIndex int
+	})
+	m.nullIndex = keyNotFound
+}
+
+func (m *{{.lower}}MemoTableImpl) GetNull() (int, bool) {
+	return m.nullIndex, m.nullIndex != keyNotFound
+}
+
+func (m *{{.lower}}MemoTableImpl) Size() int {
+	sz := len(m.table)
+	if _, ok := m.GetNull(); ok {
+		sz++
+	}
+	return sz
+}
+
+func (m *{{.lower}}MemoTableImpl) GetOrInsertNull() (idx int, found bool) {
+	idx, found = m.GetNull()
+	if !found {
+		idx = m.Size()
+		m.nullIndex = idx
+	}
+	return
+}
+
+func (m *{{.lower}}MemoTableImpl) Get(val interface{}) (int, bool) {
+	v := val.({{.name}})
+	if p, ok := m.table[v]; ok {
+		return p.memoIndex, true
+	}
+	return keyNotFound, false
+}
+
+func (m *{{.lower}}MemoTableImpl) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	v := val.({{.name}})
+	p, ok := m.table[v]
+	if ok {
+		idx = p.memoIndex
+		found = true
+	} else {
+		idx = m.Size()
+		p.value = v
+		p.memoIndex = idx
+		m.table[v] = p
+	}
+	return
+}
+
+func (m *{{.lower}}MemoTableImpl) CopyValues(out interface{}) {
+	m.CopyValuesSubset(0, out)
+}
+
+func (m *{{.lower}}MemoTableImpl) CopyValuesSubset(start int, out interface{}) {
+	outval := out.([]{{.name}})
+	for _, v := range m.table {
+		idx := v.memoIndex - start
+		if idx >= 0 {
+			outval[idx] = v.value
+		}
+	}
+}
+{{end}}
+{{end}}
diff --git a/go/parquet/internal/encoding/typed_encoder.gen.go b/go/parquet/internal/encoding/typed_encoder.gen.go
index abcfd95142e15..192286f987c42 100644
--- a/go/parquet/internal/encoding/typed_encoder.gen.go
+++ b/go/parquet/internal/encoding/typed_encoder.gen.go
@@ -74,6 +74,9 @@ type int32EncoderTraits struct{}
 
 // Encoder returns an encoder for int32 type data, using the specified encoding type and whether or not
 // it should be dictionary encoded.
func (int32EncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder { + if useDict { + return &DictInt32Encoder{newDictEncoderBase(descr, NewInt32Dictionary(), mem)} + } switch e { case format.Encoding_PLAIN: @@ -287,6 +290,9 @@ type int64EncoderTraits struct{} // Encoder returns an encoder for int64 type data, using the specified encoding type and whether or not // it should be dictionary encoded. func (int64EncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder { + if useDict { + return &DictInt64Encoder{newDictEncoderBase(descr, NewInt64Dictionary(), mem)} + } switch e { case format.Encoding_PLAIN: @@ -501,6 +507,9 @@ type int96EncoderTraits struct{} // it should be dictionary encoded. // dictionary encoding does not exist for this type and Encoder will panic if useDict is true func (int96EncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder { + if useDict { + panic("parquet: no parquet.Int96 dictionary encoding") + } switch e { case format.Encoding_PLAIN: @@ -555,6 +564,9 @@ type float32EncoderTraits struct{} // Encoder returns an encoder for float32 type data, using the specified encoding type and whether or not // it should be dictionary encoded. func (float32EncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder { + if useDict { + return &DictFloat32Encoder{newDictEncoderBase(descr, NewFloat32Dictionary(), mem)} + } switch e { case format.Encoding_PLAIN: @@ -756,6 +768,9 @@ type float64EncoderTraits struct{} // Encoder returns an encoder for float64 type data, using the specified encoding type and whether or not // it should be dictionary encoded. func (float64EncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder { + if useDict { + return &DictFloat64Encoder{newDictEncoderBase(descr, NewFloat64Dictionary(), mem)} + } switch e { case format.Encoding_PLAIN: @@ -958,6 +973,9 @@ type boolEncoderTraits struct{} // it should be dictionary encoded. // dictionary encoding does not exist for this type and Encoder will panic if useDict is true func (boolEncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder { + if useDict { + panic("parquet: no bool dictionary encoding") + } switch e { case format.Encoding_PLAIN: @@ -1012,6 +1030,9 @@ type byteArrayEncoderTraits struct{} // Encoder returns an encoder for byteArray type data, using the specified encoding type and whether or not // it should be dictionary encoded. func (byteArrayEncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder { + if useDict { + return &DictByteArrayEncoder{newDictEncoderBase(descr, NewBinaryDictionary(mem), mem)} + } switch e { case format.Encoding_PLAIN: @@ -1217,6 +1238,9 @@ type fixedLenByteArrayEncoderTraits struct{} // Encoder returns an encoder for fixedLenByteArray type data, using the specified encoding type and whether or not // it should be dictionary encoded. 
func (fixedLenByteArrayEncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder { + if useDict { + return &DictFixedLenByteArrayEncoder{newDictEncoderBase(descr, NewBinaryDictionary(mem), mem)} + } switch e { case format.Encoding_PLAIN: diff --git a/go/parquet/internal/encoding/typed_encoder.gen.go.tmpl b/go/parquet/internal/encoding/typed_encoder.gen.go.tmpl index 509266b6878b4..0667143ac0734 100644 --- a/go/parquet/internal/encoding/typed_encoder.gen.go.tmpl +++ b/go/parquet/internal/encoding/typed_encoder.gen.go.tmpl @@ -60,13 +60,13 @@ type {{.lower}}EncoderTraits struct{} // dictionary encoding does not exist for this type and Encoder will panic if useDict is true {{- end }} func ({{.lower}}EncoderTraits) Encoder(e format.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder { - {{/* if useDict { + if useDict { {{- if or (eq .Name "Boolean") (eq .Name "Int96")}} panic("parquet: no {{.name}} dictionary encoding") {{- else}} return &Dict{{.Name}}Encoder{newDictEncoderBase(descr, New{{if and (ne .Name "ByteArray") (ne .Name "FixedLenByteArray")}}{{.Name}}Dictionary(){{else}}BinaryDictionary(mem){{end}}, mem)} {{- end}} - } */}} + } switch e { case format.Encoding_PLAIN: diff --git a/go/parquet/internal/encoding/types.go b/go/parquet/internal/encoding/types.go index fa3661e111928..ed3a5c8abce10 100644 --- a/go/parquet/internal/encoding/types.go +++ b/go/parquet/internal/encoding/types.go @@ -63,7 +63,8 @@ type TypedEncoder interface { EstimatedDataEncodedSize() int64 // FlushValues finishes up any unwritten data and returns the buffer of data passing // ownership to the caller, Release needs to be called on the Buffer to free the memory - FlushValues() Buffer + // if error is nil + FlushValues() (Buffer, error) // Encoding returns the type of encoding that this encoder operates with Encoding() parquet.Encoding // Allocator returns the allocator that was used when creating this encoder @@ -78,7 +79,7 @@ type DictEncoder interface { TypedEncoder // WriteIndices populates the byte slice with the final indexes of data and returns // the number of bytes written - WriteIndices(out []byte) int + WriteIndices(out []byte) (int, error) // DictEncodedSize returns the current size of the encoded dictionary index. DictEncodedSize() int // BitWidth returns the bitwidth needed to encode all of the index values based @@ -435,63 +436,3 @@ func (b *BufferWriter) Seek(offset int64, whence int) (int64, error) { func (b *BufferWriter) Tell() int64 { return int64(b.pos) } - -// MemoTable interface that can be used to swap out implementations of the hash table -// used for handling dictionary encoding. Dictionary encoding is built against this interface -// to make it easy for code generation and changing implementations. -// -// Values should remember the order they are inserted to generate a valid dictionary index -type MemoTable interface { - // Reset drops everything in the table allowing it to be reused - Reset() - // Size returns the current number of unique values stored in the table - // including whether or not a null value has been passed in using GetOrInsertNull - Size() int - // CopyValues populates out with the values currently in the table, out must - // be a slice of the appropriate type for the table type. - CopyValues(out interface{}) - // CopyValuesSubset is like CopyValues but only copies a subset of values starting - // at the indicated index. 
- CopyValuesSubset(start int, out interface{}) - // Get returns the index of the table the specified value is, and a boolean indicating - // whether or not the value was found in the table. Will panic if val is not the appropriate - // type for the underlying table. - Get(val interface{}) (int, bool) - // GetOrInsert is the same as Get, except if the value is not currently in the table it will - // be inserted into the table. - GetOrInsert(val interface{}) (idx int, existed bool, err error) - // GetNull returns the index of the null value and whether or not it was found in the table - GetNull() (int, bool) - // GetOrInsertNull returns the index of the null value, if it didn't already exist in the table, - // it is inserted. - GetOrInsertNull() (idx int, existed bool) -} - -// BinaryMemoTable is an extension of the MemoTable interface adding extra methods -// for handling byte arrays/strings/fixed length byte arrays. -type BinaryMemoTable interface { - MemoTable - // ValuesSize returns the total number of bytes needed to copy all of the values - // from this table. - ValuesSize() int - // CopyOffsets populates out with the start and end offsets of each value in the - // table data. Out should be sized to Size()+1 to accomodate all of the offsets. - CopyOffsets(out []int8) - // CopyOffsetsSubset is like CopyOffsets but only gets a subset of the offsets - // starting at the specified index. - CopyOffsetsSubset(start int, out []int8) - // CopyFixedWidthValues exists to cope with the fact that the table doesn't track - // the fixed width when inserting the null value into the databuffer populating - // a zero length byte slice for the null value (if found). - CopyFixedWidthValues(start int, width int, out []byte) - // VisitValues calls visitFn on each value in the table starting with the index specified - VisitValues(start int, visitFn func([]byte)) - // Retain increases the reference count of the separately stored binary data that is - // kept alongside the table which contains all of the values in the table. This is - // safe to call simultaneously across multiple goroutines. - Retain() - // Release decreases the reference count by 1 of the separately stored binary data - // kept alongside the table containing the values. When the reference count goes to - // 0, the memory is freed. This is safe to call across multiple goroutines simultaneoulsy. - Release() -} diff --git a/go/parquet/internal/hashing/hashing_test.go b/go/parquet/internal/hashing/hashing_test.go new file mode 100644 index 0000000000000..875424a9d494f --- /dev/null +++ b/go/parquet/internal/hashing/hashing_test.go @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
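The quality tests in this new file hash every input twice, once with seed 0 and once with seed 1, so a perfect hash family would yield 2·n distinct outputs; asserting at least 96% of that bounds the observed collision rate at 4%. A helper expressing that metric directly might read as follows (a sketch written against the package's unexported hashInt, using the same (value, seed) call shape the tests use):

	// collisionRate reports the fraction of hash outputs lost to collisions
	// when hashing each input once with seed 0 and once with seed 1.
	func collisionRate(values []uint64) float64 {
		seen := make(map[uint64]struct{})
		for _, v := range values {
			seen[hashInt(v, 0)] = struct{}{}
			seen[hashInt(v, 1)] = struct{}{}
		}
		return 1 - float64(len(seen))/float64(2*len(values))
	}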
+ +package hashing + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" +) + +func MakeDistinctIntegers(nvals int) map[int]bool { + r := rand.New(rand.NewSource(42)) + values := make(map[int]bool) + for len(values) < nvals { + values[r.Int()] = true + } + return values +} + +func MakeSequentialIntegers(nvals int) map[int]bool { + values := make(map[int]bool) + for i := 0; i < nvals; i++ { + values[i] = true + } + return values +} + +func MakeDistinctStrings(nvals int) map[string]bool { + values := make(map[string]bool) + + r := rand.New(rand.NewSource(42)) + + max := 'z' + min := '0' + for len(values) < nvals { + data := make([]byte, r.Intn(24)) + for idx := range data { + data[idx] = byte(r.Intn(int(max-min+1)) + int(min)) + } + values[string(data)] = true + } + return values +} + +func TestHashingQualityInt(t *testing.T) { + const nvalues = 10000 + + tests := []struct { + name string + values map[int]bool + quality float64 + }{ + {"distinct", MakeDistinctIntegers(nvalues), 0.96}, + {"sequential", MakeSequentialIntegers(nvalues), 0.96}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hashes := make(map[uint64]bool) + for k := range tt.values { + hashes[hashInt(uint64(k), 0)] = true + hashes[hashInt(uint64(k), 1)] = true + } + assert.GreaterOrEqual(t, float64(len(hashes)), tt.quality*float64(2*len(tt.values))) + }) + } +} + +func TestHashingBoundsStrings(t *testing.T) { + sizes := []int{1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 18, 19, 20, 21} + for _, s := range sizes { + str := make([]byte, s) + for idx := range str { + str[idx] = uint8(idx) + } + + h := hash(str, 1) + diff := 0 + for i := 0; i < 120; i++ { + str[len(str)-1] = uint8(i) + if hash(str, 1) != h { + diff++ + } + } + assert.GreaterOrEqual(t, diff, 118) + } +} + +func TestHashingQualityString(t *testing.T) { + const nvalues = 10000 + values := MakeDistinctStrings(nvalues) + + hashes := make(map[uint64]bool) + for k := range values { + hashes[hashString(k, 0)] = true + hashes[hashString(k, 1)] = true + } + assert.GreaterOrEqual(t, float64(len(hashes)), 0.96*float64(2*len(values))) +} diff --git a/go/parquet/internal/hashing/types.tmpldata b/go/parquet/internal/hashing/types.tmpldata new file mode 100644 index 0000000000000..2e97e9814e078 --- /dev/null +++ b/go/parquet/internal/hashing/types.tmpldata @@ -0,0 +1,18 @@ +[ + { + "Name": "Int32", + "name": "int32" + }, + { + "Name": "Int64", + "name": "int64" + }, + { + "Name": "Float32", + "name": "float32" + }, + { + "Name": "Float64", + "name": "float64" + } +] diff --git a/go/parquet/internal/hashing/xxh3_memo_table.gen.go b/go/parquet/internal/hashing/xxh3_memo_table.gen.go new file mode 100644 index 0000000000000..b2ebd87aaa091 --- /dev/null +++ b/go/parquet/internal/hashing/xxh3_memo_table.gen.go @@ -0,0 +1,1013 @@ +// Code generated by xxh3_memo_table.gen.go.tmpl. DO NOT EDIT. + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hashing
+
+import (
+	"math"
+
+	"github.com/apache/arrow/go/arrow/bitutil"
+)
+
+type payloadInt32 struct {
+	val     int32
+	memoIdx int32
+}
+
+type entryInt32 struct {
+	h       uint64
+	payload payloadInt32
+}
+
+func (e entryInt32) Valid() bool { return e.h != sentinel }
+
+// Int32HashTable is a hashtable specifically for int32 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Int32HashTable struct {
+	cap     uint64
+	capMask uint64
+	size    uint64
+
+	entries []entryInt32
+}
+
+// NewInt32HashTable returns a new hash table for int32 values
+// initialized with the passed in capacity or 32, whichever is larger.
+func NewInt32HashTable(cap uint64) *Int32HashTable {
+	initCap := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	ret := &Int32HashTable{cap: initCap, capMask: initCap - 1, size: 0}
+	ret.entries = make([]entryInt32, initCap)
+	return ret
+}
+
+// Reset drops all of the values in this hash table and re-initializes it
+// with the specified initial capacity as if by calling New, but without having
+// to reallocate the object.
+func (h *Int32HashTable) Reset(cap uint64) {
+	h.cap = uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.capMask = h.cap - 1
+	h.size = 0
+	h.entries = make([]entryInt32, h.cap)
+}
+
+// CopyValues is used for copying the values out of the hash table into the
+// passed in slice, in the order that they were first inserted
+func (h *Int32HashTable) CopyValues(out []int32) {
+	h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies a subset of the values in the hashtable out, starting
+// with the value at start, in the order that they were inserted.
+func (h *Int32HashTable) CopyValuesSubset(start int, out []int32) {
+	h.VisitEntries(func(e *entryInt32) {
+		idx := e.payload.memoIdx - int32(start)
+		if idx >= 0 {
+			out[idx] = e.payload.val
+		}
+	})
+}
+
+func (h *Int32HashTable) needUpsize() bool { return h.size*uint64(loadFactor) >= h.cap }
+
+func (Int32HashTable) fixHash(v uint64) uint64 {
+	if v == sentinel {
+		return 42
+	}
+	return v
+}
+
+// Lookup retrieves the entry for a given hash value assuming its payload value returns
+// true when passed to the cmp func. Returns a pointer to the entry for the given hash value,
+// and a boolean as to whether it was found. It is not safe to use the pointer if the bool is false.
+func (h *Int32HashTable) Lookup(v uint64, cmp func(int32) bool) (*entryInt32, bool) {
+	idx, ok := h.lookup(v, h.capMask, cmp)
+	return &h.entries[idx], ok
+}
+
+func (h *Int32HashTable) lookup(v uint64, szMask uint64, cmp func(int32) bool) (uint64, bool) {
+	const perturbShift uint8 = 5
+
+	var (
+		idx     uint64
+		perturb uint64
+		e       *entryInt32
+	)
+
+	v = h.fixHash(v)
+	idx = v & szMask
+	perturb = (v >> uint64(perturbShift)) + 1
+
+	for {
+		e = &h.entries[idx]
+		if e.h == v && cmp(e.payload.val) {
+			return idx, true
+		}
+
+		if e.h == sentinel {
+			return idx, false
+		}
+
+		// perturbation logic inspired by CPython's set/dict object:
+		// the goal is that all 64 bits of the unmasked hash value eventually
+		// participate in the probing sequence, to minimize clustering
+		idx = (idx + perturb) & szMask
+		perturb = (perturb >> uint64(perturbShift)) + 1
+	}
+}
+
+func (h *Int32HashTable) upsize(newcap uint64) error {
+	newMask := newcap - 1
+
+	oldEntries := h.entries
+	h.entries = make([]entryInt32, newcap)
+	for _, e := range oldEntries {
+		if e.Valid() {
+			idx, _ := h.lookup(e.h, newMask, func(int32) bool { return false })
+			h.entries[idx] = e
+		}
+	}
+	h.cap = newcap
+	h.capMask = newMask
+	return nil
+}
+
+// Insert updates the given entry with the provided hash value, payload value and memo index.
+// The entry pointer must have been retrieved via lookup in order to actually insert properly.
+func (h *Int32HashTable) Insert(e *entryInt32, v uint64, val int32, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries will call the passed in function on each *valid* entry in the hash table,
+// a valid entry being one which has had a value inserted into it.
+func (h *Int32HashTable) VisitEntries(visit func(*entryInt32)) {
+	for _, e := range h.entries {
+		if e.Valid() {
+			visit(&e)
+		}
+	}
+}
+
+// Int32MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Int32MemoTable struct {
+	tbl     *Int32HashTable
+	nullIdx int32
+}
+
+// NewInt32MemoTable returns a new memotable with num entries pre-allocated to reduce further
+// allocations when inserting.
+func NewInt32MemoTable(num int64) *Int32MemoTable {
+	return &Int32MemoTable{tbl: NewInt32HashTable(uint64(num)), nullIdx: KeyNotFound}
+}
+
+// Reset allows this table to be re-used by dumping all the data currently in the table.
+func (s *Int32MemoTable) Reset() {
+	s.tbl.Reset(32)
+	s.nullIdx = KeyNotFound
+}
+
+// Size returns the current number of inserted elements into the table including if a null
+// has been inserted.
+func (s *Int32MemoTable) Size() int {
+	sz := int(s.tbl.size)
+	if _, ok := s.GetNull(); ok {
+		sz++
+	}
+	return sz
+}
+
+// GetNull returns the index of an inserted null or KeyNotFound along with a bool
+// that will be true if found and false if not.
+func (s *Int32MemoTable) GetNull() (int, bool) {
+	return int(s.nullIdx), s.nullIdx != KeyNotFound
+}
+
+// GetOrInsertNull will return the index of the null entry or insert a null entry
+// if one currently doesn't exist. The found value will be true if there was already
+// a null in the table, and false if it inserted one.
+func (s *Int32MemoTable) GetOrInsertNull() (idx int, found bool) { + idx, found = s.GetNull() + if !found { + idx = s.Size() + s.nullIdx = int32(idx) + } + return +} + +// CopyValues will copy the values from the memo table out into the passed in slice +// which must be of the appropriate type. +func (s *Int32MemoTable) CopyValues(out interface{}) { + s.CopyValuesSubset(0, out) +} + +// CopyValuesSubset is like CopyValues but only copies a subset of values starting +// at the provided start index +func (s *Int32MemoTable) CopyValuesSubset(start int, out interface{}) { + s.tbl.CopyValuesSubset(start, out.([]int32)) +} + +// Get returns the index of the requested value in the hash table or KeyNotFound +// along with a boolean indicating if it was found or not. +func (s *Int32MemoTable) Get(val interface{}) (int, bool) { + + h := hashInt(uint64(val.(int32)), 0) + if e, ok := s.tbl.Lookup(h, func(v int32) bool { return val.(int32) == v }); ok { + return int(e.payload.memoIdx), ok + } + return KeyNotFound, false +} + +// GetOrInsert will return the index of the specified value in the table, or insert the +// value into the table and return the new index. found indicates whether or not it already +// existed in the table (true) or was inserted by this call (false). +func (s *Int32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + + h := hashInt(uint64(val.(int32)), 0) + e, ok := s.tbl.Lookup(h, func(v int32) bool { + return val.(int32) == v + }) + + if ok { + idx = int(e.payload.memoIdx) + found = true + } else { + idx = s.Size() + s.tbl.Insert(e, h, val.(int32), int32(idx)) + } + return +} + +type payloadInt64 struct { + val int64 + memoIdx int32 +} + +type entryInt64 struct { + h uint64 + payload payloadInt64 +} + +func (e entryInt64) Valid() bool { return e.h != sentinel } + +// Int64HashTable is a hashtable specifically for int64 that +// is utilized with the MemoTable to generalize interactions for easier +// implementation of dictionaries without losing performance. +type Int64HashTable struct { + cap uint64 + capMask uint64 + size uint64 + + entries []entryInt64 +} + +// NewInt64HashTable returns a new hash table for int64 values +// initialized with the passed in capacity or 32 whichever is larger. +func NewInt64HashTable(cap uint64) *Int64HashTable { + initCap := uint64(bitutil.NextPowerOf2(int(max(cap, 32)))) + ret := &Int64HashTable{cap: initCap, capMask: initCap - 1, size: 0} + ret.entries = make([]entryInt64, initCap) + return ret +} + +// Reset drops all of the values in this hash table and re-initializes it +// with the specified initial capacity as if by calling New, but without having +// to reallocate the object. +func (h *Int64HashTable) Reset(cap uint64) { + h.cap = uint64(bitutil.NextPowerOf2(int(max(cap, 32)))) + h.capMask = h.cap - 1 + h.size = 0 + h.entries = make([]entryInt64, h.cap) +} + +// CopyValues is used for copying the values out of the hash table into the +// passed in slice, in the order that they were first inserted +func (h *Int64HashTable) CopyValues(out []int64) { + h.CopyValuesSubset(0, out) +} + +// CopyValuesSubset copies a subset of the values in the hashtable out, starting +// with the value at start, in the order that they were inserted. 
+func (h *Int64HashTable) CopyValuesSubset(start int, out []int64) {
+	h.VisitEntries(func(e *entryInt64) {
+		idx := e.payload.memoIdx - int32(start)
+		if idx >= 0 {
+			out[idx] = e.payload.val
+		}
+	})
+}
+
+func (h *Int64HashTable) needUpsize() bool { return h.size*uint64(loadFactor) >= h.cap }
+
+func (Int64HashTable) fixHash(v uint64) uint64 {
+	if v == sentinel {
+		return 42
+	}
+	return v
+}
+
+// Lookup retrieves the entry for a given hash value assuming its payload value returns
+// true when passed to the cmp func. Returns a pointer to the entry for the given hash value,
+// and a boolean as to whether it was found. It is not safe to use the pointer if the bool is false.
+func (h *Int64HashTable) Lookup(v uint64, cmp func(int64) bool) (*entryInt64, bool) {
+	idx, ok := h.lookup(v, h.capMask, cmp)
+	return &h.entries[idx], ok
+}
+
+func (h *Int64HashTable) lookup(v uint64, szMask uint64, cmp func(int64) bool) (uint64, bool) {
+	const perturbShift uint8 = 5
+
+	var (
+		idx     uint64
+		perturb uint64
+		e       *entryInt64
+	)
+
+	v = h.fixHash(v)
+	idx = v & szMask
+	perturb = (v >> uint64(perturbShift)) + 1
+
+	for {
+		e = &h.entries[idx]
+		if e.h == v && cmp(e.payload.val) {
+			return idx, true
+		}
+
+		if e.h == sentinel {
+			return idx, false
+		}
+
+		// perturbation logic inspired by CPython's set/dict object
+		// the goal is that all 64 bits of the unmasked hash value eventually
+		// participate in the probing sequence, to minimize clustering
+		idx = (idx + perturb) & szMask
+		perturb = (perturb >> uint64(perturbShift)) + 1
+	}
+}
+
+func (h *Int64HashTable) upsize(newcap uint64) error {
+	newMask := newcap - 1
+
+	oldEntries := h.entries
+	h.entries = make([]entryInt64, newcap)
+	for _, e := range oldEntries {
+		if e.Valid() {
+			idx, _ := h.lookup(e.h, newMask, func(int64) bool { return false })
+			h.entries[idx] = e
+		}
+	}
+	h.cap = newcap
+	h.capMask = newMask
+	return nil
+}
+
+// Insert updates the given entry with the provided hash value, payload value and memo index.
+// The entry pointer must have been retrieved via lookup in order to actually insert properly.
+func (h *Int64HashTable) Insert(e *entryInt64, v uint64, val int64, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries will call the passed in function on each *valid* entry in the hash table,
+// a valid entry being one which has had a value inserted into it.
+func (h *Int64HashTable) VisitEntries(visit func(*entryInt64)) {
+	for _, e := range h.entries {
+		if e.Valid() {
+			visit(&e)
+		}
+	}
+}
+
+// Int64MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Int64MemoTable struct {
+	tbl     *Int64HashTable
+	nullIdx int32
+}
+
+// NewInt64MemoTable returns a new memotable with num entries pre-allocated to reduce further
+// allocations when inserting.
+func NewInt64MemoTable(num int64) *Int64MemoTable {
+	return &Int64MemoTable{tbl: NewInt64HashTable(uint64(num)), nullIdx: KeyNotFound}
+}
+
+// Reset allows this table to be re-used by dumping all the data currently in the table.
+func (s *Int64MemoTable) Reset() {
+	s.tbl.Reset(32)
+	s.nullIdx = KeyNotFound
+}
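
Because the hash table records the memo index at insertion time, the memo tables hand out dictionary indexes in first-seen order, which is exactly the property dictionary encoding needs. A sketch of how a caller might dictionary-encode a value stream (illustrative only: the hashing package is internal to the parquet module, so external code cannot actually import it):

```go
package main

import (
	"fmt"

	"github.com/apache/arrow/go/parquet/internal/hashing"
)

func main() {
	// Dictionary-encode a stream of values: every distinct value gets the
	// next index, and repeats return the index assigned on first sight.
	memo := hashing.NewInt64MemoTable(0)

	data := []int64{7, 7, 42, 7, 13, 42}
	indices := make([]int32, 0, len(data))
	for _, v := range data {
		idx, _, err := memo.GetOrInsert(v)
		if err != nil {
			panic(err)
		}
		indices = append(indices, int32(idx))
	}

	// Recover the dictionary itself, in insertion order.
	dict := make([]int64, memo.Size())
	memo.CopyValues(dict)

	fmt.Println(indices) // [0 0 1 0 2 1]
	fmt.Println(dict)    // [7 42 13]
}
```

+
+// Size returns the current number of inserted elements into the table including if a null
+// has been inserted.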
+func (s *Int64MemoTable) Size() int { + sz := int(s.tbl.size) + if _, ok := s.GetNull(); ok { + sz++ + } + return sz +} + +// GetNull returns the index of an inserted null or KeyNotFound along with a bool +// that will be true if found and false if not. +func (s *Int64MemoTable) GetNull() (int, bool) { + return int(s.nullIdx), s.nullIdx != KeyNotFound +} + +// GetOrInsertNull will return the index of the null entry or insert a null entry +// if one currently doesn't exist. The found value will be true if there was already +// a null in the table, and false if it inserted one. +func (s *Int64MemoTable) GetOrInsertNull() (idx int, found bool) { + idx, found = s.GetNull() + if !found { + idx = s.Size() + s.nullIdx = int32(idx) + } + return +} + +// CopyValues will copy the values from the memo table out into the passed in slice +// which must be of the appropriate type. +func (s *Int64MemoTable) CopyValues(out interface{}) { + s.CopyValuesSubset(0, out) +} + +// CopyValuesSubset is like CopyValues but only copies a subset of values starting +// at the provided start index +func (s *Int64MemoTable) CopyValuesSubset(start int, out interface{}) { + s.tbl.CopyValuesSubset(start, out.([]int64)) +} + +// Get returns the index of the requested value in the hash table or KeyNotFound +// along with a boolean indicating if it was found or not. +func (s *Int64MemoTable) Get(val interface{}) (int, bool) { + + h := hashInt(uint64(val.(int64)), 0) + if e, ok := s.tbl.Lookup(h, func(v int64) bool { return val.(int64) == v }); ok { + return int(e.payload.memoIdx), ok + } + return KeyNotFound, false +} + +// GetOrInsert will return the index of the specified value in the table, or insert the +// value into the table and return the new index. found indicates whether or not it already +// existed in the table (true) or was inserted by this call (false). +func (s *Int64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + + h := hashInt(uint64(val.(int64)), 0) + e, ok := s.tbl.Lookup(h, func(v int64) bool { + return val.(int64) == v + }) + + if ok { + idx = int(e.payload.memoIdx) + found = true + } else { + idx = s.Size() + s.tbl.Insert(e, h, val.(int64), int32(idx)) + } + return +} + +type payloadFloat32 struct { + val float32 + memoIdx int32 +} + +type entryFloat32 struct { + h uint64 + payload payloadFloat32 +} + +func (e entryFloat32) Valid() bool { return e.h != sentinel } + +// Float32HashTable is a hashtable specifically for float32 that +// is utilized with the MemoTable to generalize interactions for easier +// implementation of dictionaries without losing performance. +type Float32HashTable struct { + cap uint64 + capMask uint64 + size uint64 + + entries []entryFloat32 +} + +// NewFloat32HashTable returns a new hash table for float32 values +// initialized with the passed in capacity or 32 whichever is larger. +func NewFloat32HashTable(cap uint64) *Float32HashTable { + initCap := uint64(bitutil.NextPowerOf2(int(max(cap, 32)))) + ret := &Float32HashTable{cap: initCap, capMask: initCap - 1, size: 0} + ret.entries = make([]entryFloat32, initCap) + return ret +} + +// Reset drops all of the values in this hash table and re-initializes it +// with the specified initial capacity as if by calling New, but without having +// to reallocate the object. 
+func (h *Float32HashTable) Reset(cap uint64) {
+	h.cap = uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.capMask = h.cap - 1
+	h.size = 0
+	h.entries = make([]entryFloat32, h.cap)
+}
+
+// CopyValues is used for copying the values out of the hash table into the
+// passed in slice, in the order that they were first inserted
+func (h *Float32HashTable) CopyValues(out []float32) {
+	h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies a subset of the values in the hashtable out, starting
+// with the value at start, in the order that they were inserted.
+func (h *Float32HashTable) CopyValuesSubset(start int, out []float32) {
+	h.VisitEntries(func(e *entryFloat32) {
+		idx := e.payload.memoIdx - int32(start)
+		if idx >= 0 {
+			out[idx] = e.payload.val
+		}
+	})
+}
+
+func (h *Float32HashTable) needUpsize() bool { return h.size*uint64(loadFactor) >= h.cap }
+
+func (Float32HashTable) fixHash(v uint64) uint64 {
+	if v == sentinel {
+		return 42
+	}
+	return v
+}
+
+// Lookup retrieves the entry for a given hash value assuming its payload value returns
+// true when passed to the cmp func. Returns a pointer to the entry for the given hash value,
+// and a boolean as to whether it was found. It is not safe to use the pointer if the bool is false.
+func (h *Float32HashTable) Lookup(v uint64, cmp func(float32) bool) (*entryFloat32, bool) {
+	idx, ok := h.lookup(v, h.capMask, cmp)
+	return &h.entries[idx], ok
+}
+
+func (h *Float32HashTable) lookup(v uint64, szMask uint64, cmp func(float32) bool) (uint64, bool) {
+	const perturbShift uint8 = 5
+
+	var (
+		idx     uint64
+		perturb uint64
+		e       *entryFloat32
+	)
+
+	v = h.fixHash(v)
+	idx = v & szMask
+	perturb = (v >> uint64(perturbShift)) + 1
+
+	for {
+		e = &h.entries[idx]
+		if e.h == v && cmp(e.payload.val) {
+			return idx, true
+		}
+
+		if e.h == sentinel {
+			return idx, false
+		}
+
+		// perturbation logic inspired by CPython's set/dict object
+		// the goal is that all 64 bits of the unmasked hash value eventually
+		// participate in the probing sequence, to minimize clustering
+		idx = (idx + perturb) & szMask
+		perturb = (perturb >> uint64(perturbShift)) + 1
+	}
+}
+
+func (h *Float32HashTable) upsize(newcap uint64) error {
+	newMask := newcap - 1
+
+	oldEntries := h.entries
+	h.entries = make([]entryFloat32, newcap)
+	for _, e := range oldEntries {
+		if e.Valid() {
+			idx, _ := h.lookup(e.h, newMask, func(float32) bool { return false })
+			h.entries[idx] = e
+		}
+	}
+	h.cap = newcap
+	h.capMask = newMask
+	return nil
+}
+
+// Insert updates the given entry with the provided hash value, payload value and memo index.
+// The entry pointer must have been retrieved via lookup in order to actually insert properly.
+func (h *Float32HashTable) Insert(e *entryFloat32, v uint64, val float32, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries will call the passed in function on each *valid* entry in the hash table,
+// a valid entry being one which has had a value inserted into it.
+func (h *Float32HashTable) VisitEntries(visit func(*entryFloat32)) {
+	for _, e := range h.entries {
+		if e.Valid() {
+			visit(&e)
+		}
+	}
+}
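
Floating point keys need the special handling that shows up in the Float32MemoTable below: `NaN != NaN` under Go's `==`, and many distinct bit patterns are all NaN, so hashing raw bits alone would scatter NaNs across separate entries. The memo tables therefore match NaN with an `IsNaN` comparator and canonicalize to a single `math.NaN()` bit pattern. A standalone demonstration of the underlying problem (stdlib only, not part of the patch):

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// Two different NaN bit patterns: the canonical quiet NaN, and one
	// manufactured by setting an extra payload bit.
	a := math.NaN()
	b := math.Float64frombits(math.Float64bits(math.NaN()) | 1)

	fmt.Println(a == a)                                     // false: NaN never compares equal
	fmt.Println(math.Float64bits(a) == math.Float64bits(b)) // false: distinct bit patterns
	fmt.Println(math.IsNaN(a), math.IsNaN(b))               // true true
}
```

+
+// Float32MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.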
+type Float32MemoTable struct { + tbl *Float32HashTable + nullIdx int32 +} + +// NewFloat32MemoTable returns a new memotable with num entries pre-allocated to reduce further +// allocations when inserting. +func NewFloat32MemoTable(num int64) *Float32MemoTable { + return &Float32MemoTable{tbl: NewFloat32HashTable(uint64(num)), nullIdx: KeyNotFound} +} + +// Reset allows this table to be re-used by dumping all the data currently in the table. +func (s *Float32MemoTable) Reset() { + s.tbl.Reset(32) + s.nullIdx = KeyNotFound +} + +// Size returns the current number of inserted elements into the table including if a null +// has been inserted. +func (s *Float32MemoTable) Size() int { + sz := int(s.tbl.size) + if _, ok := s.GetNull(); ok { + sz++ + } + return sz +} + +// GetNull returns the index of an inserted null or KeyNotFound along with a bool +// that will be true if found and false if not. +func (s *Float32MemoTable) GetNull() (int, bool) { + return int(s.nullIdx), s.nullIdx != KeyNotFound +} + +// GetOrInsertNull will return the index of the null entry or insert a null entry +// if one currently doesn't exist. The found value will be true if there was already +// a null in the table, and false if it inserted one. +func (s *Float32MemoTable) GetOrInsertNull() (idx int, found bool) { + idx, found = s.GetNull() + if !found { + idx = s.Size() + s.nullIdx = int32(idx) + } + return +} + +// CopyValues will copy the values from the memo table out into the passed in slice +// which must be of the appropriate type. +func (s *Float32MemoTable) CopyValues(out interface{}) { + s.CopyValuesSubset(0, out) +} + +// CopyValuesSubset is like CopyValues but only copies a subset of values starting +// at the provided start index +func (s *Float32MemoTable) CopyValuesSubset(start int, out interface{}) { + s.tbl.CopyValuesSubset(start, out.([]float32)) +} + +// Get returns the index of the requested value in the hash table or KeyNotFound +// along with a boolean indicating if it was found or not. +func (s *Float32MemoTable) Get(val interface{}) (int, bool) { + var cmp func(float32) bool + + if math.IsNaN(float64(val.(float32))) { + cmp = isNan32Cmp + // use consistent internal bit pattern for NaN regardless of the pattern + // that is passed to us. NaN is NaN is NaN + val = float32(math.NaN()) + } else { + cmp = func(v float32) bool { return val.(float32) == v } + } + + h := hashFloat32(val.(float32), 0) + if e, ok := s.tbl.Lookup(h, cmp); ok { + return int(e.payload.memoIdx), ok + } + return KeyNotFound, false +} + +// GetOrInsert will return the index of the specified value in the table, or insert the +// value into the table and return the new index. found indicates whether or not it already +// existed in the table (true) or was inserted by this call (false). +func (s *Float32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + + var cmp func(float32) bool + + if math.IsNaN(float64(val.(float32))) { + cmp = isNan32Cmp + // use consistent internal bit pattern for NaN regardless of the pattern + // that is passed to us. 
NaN is NaN is NaN
+		val = float32(math.NaN())
+	} else {
+		cmp = func(v float32) bool { return val.(float32) == v }
+	}
+
+	h := hashFloat32(val.(float32), 0)
+	e, ok := s.tbl.Lookup(h, cmp)
+
+	if ok {
+		idx = int(e.payload.memoIdx)
+		found = true
+	} else {
+		idx = s.Size()
+		s.tbl.Insert(e, h, val.(float32), int32(idx))
+	}
+	return
+}
+
+type payloadFloat64 struct {
+	val     float64
+	memoIdx int32
+}
+
+type entryFloat64 struct {
+	h       uint64
+	payload payloadFloat64
+}
+
+func (e entryFloat64) Valid() bool { return e.h != sentinel }
+
+// Float64HashTable is a hashtable specifically for float64 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Float64HashTable struct {
+	cap     uint64
+	capMask uint64
+	size    uint64
+
+	entries []entryFloat64
+}
+
+// NewFloat64HashTable returns a new hash table for float64 values
+// initialized with the passed in capacity or 32 whichever is larger.
+func NewFloat64HashTable(cap uint64) *Float64HashTable {
+	initCap := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	ret := &Float64HashTable{cap: initCap, capMask: initCap - 1, size: 0}
+	ret.entries = make([]entryFloat64, initCap)
+	return ret
+}
+
+// Reset drops all of the values in this hash table and re-initializes it
+// with the specified initial capacity as if by calling New, but without having
+// to reallocate the object.
+func (h *Float64HashTable) Reset(cap uint64) {
+	h.cap = uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.capMask = h.cap - 1
+	h.size = 0
+	h.entries = make([]entryFloat64, h.cap)
+}
+
+// CopyValues is used for copying the values out of the hash table into the
+// passed in slice, in the order that they were first inserted
+func (h *Float64HashTable) CopyValues(out []float64) {
+	h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies a subset of the values in the hashtable out, starting
+// with the value at start, in the order that they were inserted.
+func (h *Float64HashTable) CopyValuesSubset(start int, out []float64) {
+	h.VisitEntries(func(e *entryFloat64) {
+		idx := e.payload.memoIdx - int32(start)
+		if idx >= 0 {
+			out[idx] = e.payload.val
+		}
+	})
+}
+
+func (h *Float64HashTable) needUpsize() bool { return h.size*uint64(loadFactor) >= h.cap }
+
+func (Float64HashTable) fixHash(v uint64) uint64 {
+	if v == sentinel {
+		return 42
+	}
+	return v
+}
+
+// Lookup retrieves the entry for a given hash value assuming its payload value returns
+// true when passed to the cmp func. Returns a pointer to the entry for the given hash value,
+// and a boolean as to whether it was found. It is not safe to use the pointer if the bool is false.
+func (h *Float64HashTable) Lookup(v uint64, cmp func(float64) bool) (*entryFloat64, bool) {
+	idx, ok := h.lookup(v, h.capMask, cmp)
+	return &h.entries[idx], ok
+}
+
+func (h *Float64HashTable) lookup(v uint64, szMask uint64, cmp func(float64) bool) (uint64, bool) {
+	const perturbShift uint8 = 5
+
+	var (
+		idx     uint64
+		perturb uint64
+		e       *entryFloat64
+	)
+
+	v = h.fixHash(v)
+	idx = v & szMask
+	perturb = (v >> uint64(perturbShift)) + 1
+
+	for {
+		e = &h.entries[idx]
+		if e.h == v && cmp(e.payload.val) {
+			return idx, true
+		}
+
+		if e.h == sentinel {
+			return idx, false
+		}
+
+		// perturbation logic inspired by CPython's set/dict object
+		// the goal is that all 64 bits of the unmasked hash value eventually
+		// participate in the probing sequence, to minimize clustering
+		idx = (idx + perturb) & szMask
+		perturb = (perturb >> uint64(perturbShift)) + 1
+	}
+}
+
+func (h *Float64HashTable) upsize(newcap uint64) error {
+	newMask := newcap - 1
+
+	oldEntries := h.entries
+	h.entries = make([]entryFloat64, newcap)
+	for _, e := range oldEntries {
+		if e.Valid() {
+			idx, _ := h.lookup(e.h, newMask, func(float64) bool { return false })
+			h.entries[idx] = e
+		}
+	}
+	h.cap = newcap
+	h.capMask = newMask
+	return nil
+}
+
+// Insert updates the given entry with the provided hash value, payload value and memo index.
+// The entry pointer must have been retrieved via lookup in order to actually insert properly.
+func (h *Float64HashTable) Insert(e *entryFloat64, v uint64, val float64, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries will call the passed in function on each *valid* entry in the hash table,
+// a valid entry being one which has had a value inserted into it.
+func (h *Float64HashTable) VisitEntries(visit func(*entryFloat64)) {
+	for _, e := range h.entries {
+		if e.Valid() {
+			visit(&e)
+		}
+	}
+}
+
+// Float64MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Float64MemoTable struct {
+	tbl     *Float64HashTable
+	nullIdx int32
+}
+
+// NewFloat64MemoTable returns a new memotable with num entries pre-allocated to reduce further
+// allocations when inserting.
+func NewFloat64MemoTable(num int64) *Float64MemoTable {
+	return &Float64MemoTable{tbl: NewFloat64HashTable(uint64(num)), nullIdx: KeyNotFound}
+}
+
+// Reset allows this table to be re-used by dumping all the data currently in the table.
+func (s *Float64MemoTable) Reset() {
+	s.tbl.Reset(32)
+	s.nullIdx = KeyNotFound
+}
+
+// Size returns the current number of inserted elements into the table including if a null
+// has been inserted.
+func (s *Float64MemoTable) Size() int {
+	sz := int(s.tbl.size)
+	if _, ok := s.GetNull(); ok {
+		sz++
+	}
+	return sz
+}
+
+// GetNull returns the index of an inserted null or KeyNotFound along with a bool
+// that will be true if found and false if not.
+func (s *Float64MemoTable) GetNull() (int, bool) {
+	return int(s.nullIdx), s.nullIdx != KeyNotFound
+}
+
+// GetOrInsertNull will return the index of the null entry or insert a null entry
+// if one currently doesn't exist. The found value will be true if there was already
+// a null in the table, and false if it inserted one.
+func (s *Float64MemoTable) GetOrInsertNull() (idx int, found bool) { + idx, found = s.GetNull() + if !found { + idx = s.Size() + s.nullIdx = int32(idx) + } + return +} + +// CopyValues will copy the values from the memo table out into the passed in slice +// which must be of the appropriate type. +func (s *Float64MemoTable) CopyValues(out interface{}) { + s.CopyValuesSubset(0, out) +} + +// CopyValuesSubset is like CopyValues but only copies a subset of values starting +// at the provided start index +func (s *Float64MemoTable) CopyValuesSubset(start int, out interface{}) { + s.tbl.CopyValuesSubset(start, out.([]float64)) +} + +// Get returns the index of the requested value in the hash table or KeyNotFound +// along with a boolean indicating if it was found or not. +func (s *Float64MemoTable) Get(val interface{}) (int, bool) { + var cmp func(float64) bool + if math.IsNaN(val.(float64)) { + cmp = math.IsNaN + // use consistent internal bit pattern for NaN regardless of the pattern + // that is passed to us. NaN is NaN is NaN + val = math.NaN() + } else { + cmp = func(v float64) bool { return val.(float64) == v } + } + + h := hashFloat64(val.(float64), 0) + if e, ok := s.tbl.Lookup(h, cmp); ok { + return int(e.payload.memoIdx), ok + } + return KeyNotFound, false +} + +// GetOrInsert will return the index of the specified value in the table, or insert the +// value into the table and return the new index. found indicates whether or not it already +// existed in the table (true) or was inserted by this call (false). +func (s *Float64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + + var cmp func(float64) bool + if math.IsNaN(val.(float64)) { + cmp = math.IsNaN + // use consistent internal bit pattern for NaN regardless of the pattern + // that is passed to us. NaN is NaN is NaN + val = math.NaN() + } else { + cmp = func(v float64) bool { return val.(float64) == v } + } + + h := hashFloat64(val.(float64), 0) + e, ok := s.tbl.Lookup(h, cmp) + + if ok { + idx = int(e.payload.memoIdx) + found = true + } else { + idx = s.Size() + s.tbl.Insert(e, h, val.(float64), int32(idx)) + } + return +} diff --git a/go/parquet/internal/hashing/xxh3_memo_table.gen.go.tmpl b/go/parquet/internal/hashing/xxh3_memo_table.gen.go.tmpl new file mode 100644 index 0000000000000..a56009bf8d370 --- /dev/null +++ b/go/parquet/internal/hashing/xxh3_memo_table.gen.go.tmpl @@ -0,0 +1,304 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package hashing
+
+import (
+	"github.com/apache/arrow/go/arrow/bitutil"
+)
+
+{{range .In}}
+type payload{{.Name}} struct {
+	val     {{.name}}
+	memoIdx int32
+}
+
+type entry{{.Name}} struct {
+	h       uint64
+	payload payload{{.Name}}
+}
+
+func (e entry{{.Name}}) Valid() bool { return e.h != sentinel }
+
+// {{.Name}}HashTable is a hashtable specifically for {{.name}} that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type {{.Name}}HashTable struct {
+	cap     uint64
+	capMask uint64
+	size    uint64
+
+	entries []entry{{.Name}}
+}
+
+// New{{.Name}}HashTable returns a new hash table for {{.name}} values
+// initialized with the passed in capacity or 32 whichever is larger.
+func New{{.Name}}HashTable(cap uint64) *{{.Name}}HashTable {
+	initCap := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	ret := &{{.Name}}HashTable{cap: initCap, capMask: initCap - 1, size: 0}
+	ret.entries = make([]entry{{.Name}}, initCap)
+	return ret
+}
+
+// Reset drops all of the values in this hash table and re-initializes it
+// with the specified initial capacity as if by calling New, but without having
+// to reallocate the object.
+func (h *{{.Name}}HashTable) Reset(cap uint64) {
+	h.cap = uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.capMask = h.cap - 1
+	h.size = 0
+	h.entries = make([]entry{{.Name}}, h.cap)
+}
+
+// CopyValues is used for copying the values out of the hash table into the
+// passed in slice, in the order that they were first inserted
+func (h *{{.Name}}HashTable) CopyValues(out []{{.name}}) {
+	h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies a subset of the values in the hashtable out, starting
+// with the value at start, in the order that they were inserted.
+func (h *{{.Name}}HashTable) CopyValuesSubset(start int, out []{{.name}}) {
+	h.VisitEntries(func(e *entry{{.Name}}) {
+		idx := e.payload.memoIdx - int32(start)
+		if idx >= 0 {
+			out[idx] = e.payload.val
+		}
+	})
+}
+
+func (h *{{.Name}}HashTable) needUpsize() bool { return h.size*uint64(loadFactor) >= h.cap }
+
+func ({{.Name}}HashTable) fixHash(v uint64) uint64 {
+	if v == sentinel {
+		return 42
+	}
+	return v
+}
+
+// Lookup retrieves the entry for a given hash value assuming its payload value returns
+// true when passed to the cmp func. Returns a pointer to the entry for the given hash value,
+// and a boolean as to whether it was found. It is not safe to use the pointer if the bool is false.
+func (h *{{.Name}}HashTable) Lookup(v uint64, cmp func({{.name}}) bool) (*entry{{.Name}}, bool) {
+	idx, ok := h.lookup(v, h.capMask, cmp)
+	return &h.entries[idx], ok
+}
+
+func (h *{{.Name}}HashTable) lookup(v uint64, szMask uint64, cmp func({{.name}}) bool) (uint64, bool) {
+	const perturbShift uint8 = 5
+
+	var (
+		idx     uint64
+		perturb uint64
+		e       *entry{{.Name}}
+	)
+
+	v = h.fixHash(v)
+	idx = v & szMask
+	perturb = (v >> uint64(perturbShift)) + 1
+
+	for {
+		e = &h.entries[idx]
+		if e.h == v && cmp(e.payload.val) {
+			return idx, true
+		}
+
+		if e.h == sentinel {
+			return idx, false
+		}
+
+		// perturbation logic inspired by CPython's set/dict object
+		// the goal is that all 64 bits of the unmasked hash value eventually
+		// participate in the probing sequence, to minimize clustering
+		idx = (idx + perturb) & szMask
+		perturb = (perturb >> uint64(perturbShift)) + 1
+	}
+}
+
+func (h *{{.Name}}HashTable) upsize(newcap uint64) error {
+	newMask := newcap - 1
+
+	oldEntries := h.entries
+	h.entries = make([]entry{{.Name}}, newcap)
+	for _, e := range oldEntries {
+		if e.Valid() {
+			idx, _ := h.lookup(e.h, newMask, func({{.name}}) bool { return false })
+			h.entries[idx] = e
+		}
+	}
+	h.cap = newcap
+	h.capMask = newMask
+	return nil
+}
+
+// Insert updates the given entry with the provided hash value, payload value and memo index.
+// The entry pointer must have been retrieved via lookup in order to actually insert properly.
+func (h *{{.Name}}HashTable) Insert(e *entry{{.Name}}, v uint64, val {{.name}}, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries will call the passed in function on each *valid* entry in the hash table,
+// a valid entry being one which has had a value inserted into it.
+func (h *{{.Name}}HashTable) VisitEntries(visit func(*entry{{.Name}})) {
+	for _, e := range h.entries {
+		if e.Valid() {
+			visit(&e)
+		}
+	}
+}
+
+// {{.Name}}MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type {{.Name}}MemoTable struct {
+	tbl     *{{.Name}}HashTable
+	nullIdx int32
+}
+
+// New{{.Name}}MemoTable returns a new memotable with num entries pre-allocated to reduce further
+// allocations when inserting.
+func New{{.Name}}MemoTable(num int64) *{{.Name}}MemoTable {
+	return &{{.Name}}MemoTable{tbl: New{{.Name}}HashTable(uint64(num)), nullIdx: KeyNotFound}
+}
+
+// Reset allows this table to be re-used by dumping all the data currently in the table.
+func (s *{{.Name}}MemoTable) Reset() {
+	s.tbl.Reset(32)
+	s.nullIdx = KeyNotFound
+}
+
+// Size returns the current number of inserted elements into the table including if a null
+// has been inserted.
+func (s *{{.Name}}MemoTable) Size() int {
+	sz := int(s.tbl.size)
+	if _, ok := s.GetNull(); ok {
+		sz++
+	}
+	return sz
+}
+
+// GetNull returns the index of an inserted null or KeyNotFound along with a bool
+// that will be true if found and false if not.
+func (s *{{.Name}}MemoTable) GetNull() (int, bool) {
+	return int(s.nullIdx), s.nullIdx != KeyNotFound
+}
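
The typed implementations earlier in this diff are stamped out from this template by the tmpl generator (see the `go:generate` directive in xxh3_memo_table.go). A rough standalone sketch of the mechanism, assuming `types.tmpldata` supplies `Name`/`name` pairs such as Int32/int32 (the data shape here is an assumption for illustration; the real tool lives in arrow/_tools/tmpl):

```go
package main

import (
	"os"
	"text/template"
)

// A fragment of the memo-table template: {{.Name}} becomes the exported
// type name, {{.name}} the underlying Go type.
const src = `
type payload{{.Name}} struct {
	val     {{.name}}
	memoIdx int32
}
`

func main() {
	data := []map[string]string{
		{"Name": "Int32", "name": "int32"},
		{"Name": "Float64", "name": "float64"},
	}
	tmpl := template.Must(template.New("memo").Parse(src))
	for _, d := range data {
		// Each entry expands the same template into one typed declaration.
		if err := tmpl.Execute(os.Stdout, d); err != nil {
			panic(err)
		}
	}
}
```

+
+// GetOrInsertNull will return the index of the null entry or insert a null entry
+// if one currently doesn't exist. The found value will be true if there was already
+// a null in the table, and false if it inserted one.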
+func (s *{{.Name}}MemoTable) GetOrInsertNull() (idx int, found bool) { + idx, found = s.GetNull() + if !found { + idx = s.Size() + s.nullIdx = int32(idx) + } + return +} + +// CopyValues will copy the values from the memo table out into the passed in slice +// which must be of the appropriate type. +func (s *{{.Name}}MemoTable) CopyValues(out interface{}) { + s.CopyValuesSubset(0, out) +} + +// CopyValuesSubset is like CopyValues but only copies a subset of values starting +// at the provided start index +func (s *{{.Name}}MemoTable) CopyValuesSubset(start int, out interface{}) { + s.tbl.CopyValuesSubset(start, out.([]{{.name}})) +} + +// Get returns the index of the requested value in the hash table or KeyNotFound +// along with a boolean indicating if it was found or not. +func (s *{{.Name}}MemoTable) Get(val interface{}) (int, bool) { +{{if or (eq .Name "Int32") (eq .Name "Int64") }} + h := hashInt(uint64(val.({{.name}})), 0) + if e, ok := s.tbl.Lookup(h, func(v {{.name}}) bool { return val.({{.name}}) == v }); ok { +{{ else -}} + var cmp func({{.name}}) bool + {{if eq .Name "Float32"}} + if math.IsNaN(float64(val.(float32))) { + cmp = isNan32Cmp + // use consistent internal bit pattern for NaN regardless of the pattern + // that is passed to us. NaN is NaN is NaN + val = float32(math.NaN()) + {{ else -}} + if math.IsNaN(val.(float64)) { + cmp = math.IsNaN + // use consistent internal bit pattern for NaN regardless of the pattern + // that is passed to us. NaN is NaN is NaN + val = math.NaN() + {{end -}} + } else { + cmp = func(v {{.name}}) bool { return val.({{.name}}) == v } + } + + h := hash{{.Name}}(val.({{.name}}), 0) + if e, ok := s.tbl.Lookup(h, cmp); ok { +{{ end -}} + return int(e.payload.memoIdx), ok + } + return KeyNotFound, false +} + +// GetOrInsert will return the index of the specified value in the table, or insert the +// value into the table and return the new index. found indicates whether or not it already +// existed in the table (true) or was inserted by this call (false). +func (s *{{.Name}}MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + {{if or (eq .Name "Int32") (eq .Name "Int64") }} + h := hashInt(uint64(val.({{.name}})), 0) + e, ok := s.tbl.Lookup(h, func(v {{.name}}) bool { + return val.({{.name}}) == v + }) +{{ else }} + var cmp func({{.name}}) bool + {{if eq .Name "Float32"}} + if math.IsNaN(float64(val.(float32))) { + cmp = isNan32Cmp + // use consistent internal bit pattern for NaN regardless of the pattern + // that is passed to us. NaN is NaN is NaN + val = float32(math.NaN()) + {{ else -}} + if math.IsNaN(val.(float64)) { + cmp = math.IsNaN + // use consistent internal bit pattern for NaN regardless of the pattern + // that is passed to us. NaN is NaN is NaN + val = math.NaN() + {{end -}} + } else { + cmp = func(v {{.name}}) bool { return val.({{.name}}) == v } + } + + h := hash{{.Name}}(val.({{.name}}), 0) + e, ok := s.tbl.Lookup(h, cmp) +{{ end }} + if ok { + idx = int(e.payload.memoIdx) + found = true + } else { + idx = s.Size() + s.tbl.Insert(e, h, val.({{.name}}), int32(idx)) + } + return +} +{{end}} diff --git a/go/parquet/internal/hashing/xxh3_memo_table.go b/go/parquet/internal/hashing/xxh3_memo_table.go new file mode 100644 index 0000000000000..dd1ee6cf58f0e --- /dev/null +++ b/go/parquet/internal/hashing/xxh3_memo_table.go @@ -0,0 +1,386 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hashing provides utilities for and an implementation of a hash
+// table which is more performant than the default Go map implementation
+// by leveraging xxh3 and some custom hash functions.
+package hashing
+
+import (
+	"bytes"
+	"math"
+	"math/bits"
+	"reflect"
+	"unsafe"
+
+	"github.com/apache/arrow/go/arrow"
+	"github.com/apache/arrow/go/arrow/array"
+	"github.com/apache/arrow/go/arrow/memory"
+	"github.com/apache/arrow/go/parquet"
+
+	"github.com/zeebo/xxh3"
+)
+
+//go:generate go run ../../../arrow/_tools/tmpl/main.go -i -data=types.tmpldata xxh3_memo_table.gen.go.tmpl
+
+func hashInt(val uint64, alg uint64) uint64 {
+	// Two of xxhash's prime multipliers (which are chosen for their
+	// bit dispersion properties)
+	var multipliers = [2]uint64{11400714785074694791, 14029467366897019727}
+	// Multiplying by the prime number mixes the low bits into the high bits,
+	// then byte-swapping (which is a single CPU instruction) allows the
+	// combined high and low bits to participate in the initial hash table index.
+	return bits.ReverseBytes64(multipliers[alg] * val)
+}
+
+func hashFloat32(val float32, alg uint64) uint64 {
+	// grab the raw byte pattern of the float32 value as a uint32
+	bt := *(*[4]byte)(unsafe.Pointer(&val))
+	x := uint64(*(*uint32)(unsafe.Pointer(&bt[0])))
+	hx := hashInt(x, alg)
+	hy := hashInt(x, alg^1)
+	return 4 ^ hx ^ hy
+}
+
+func hashFloat64(val float64, alg uint64) uint64 {
+	bt := *(*[8]byte)(unsafe.Pointer(&val))
+	hx := hashInt(uint64(*(*uint32)(unsafe.Pointer(&bt[4]))), alg)
+	hy := hashInt(uint64(*(*uint32)(unsafe.Pointer(&bt[0]))), alg^1)
+	return 8 ^ hx ^ hy
+}
+
+func hashString(val string, alg uint64) uint64 {
+	buf := *(*[]byte)(unsafe.Pointer(&val))
+	(*reflect.SliceHeader)(unsafe.Pointer(&buf)).Cap = len(val)
+	return hash(buf, alg)
+}
+
+// prime constants used for slightly increasing the hash quality further
+var exprimes = [2]uint64{1609587929392839161, 9650029242287828579}
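
The multiply-then-byte-swap trick in `hashInt` is what lets the low bits of the table index depend on all bits of the key. A standalone illustration (`mix` is a hypothetical stand-in for `hashInt` with a fixed multiplier, not part of the patch):

```go
package main

import (
	"fmt"
	"math/bits"
)

// mix multiplies by a large odd prime, which pushes low-bit information up
// into the high bits, then byte-swaps so those mixed high bits come back
// down where a power-of-two table mask can see them.
func mix(v uint64) uint64 {
	const prime = 11400714785074694791
	return bits.ReverseBytes64(prime * v)
}

func main() {
	// Small sequential keys land far apart after mixing, even in the low
	// four bits that a 16-entry table mask would actually use.
	for v := uint64(1); v <= 4; v++ {
		fmt.Printf("%d -> %#016x (low 4 bits: %d)\n", v, mix(v), mix(v)&15)
	}
}
```

+// for smaller amounts of bytes this is faster than even calling into
+// xxh3 to do the hash, so we specialize in order to get the benefits
+// of that performance.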
+func hash(b []byte, alg uint64) uint64 {
+	n := uint32(len(b))
+	if n <= 16 {
+		switch {
+		case n > 8:
+			// 8 < length <= 16
+			// apply the same principle as the 32-bit case below, but with two 64-bit ints
+			x := *(*uint64)(unsafe.Pointer(&b[n-8]))
+			y := *(*uint64)(unsafe.Pointer(&b[0]))
+			hx := hashInt(x, alg)
+			hy := hashInt(y, alg^1)
+			return uint64(n) ^ hx ^ hy
+		case n >= 4:
+			// 4 <= length <= 8
+			// we can read the bytes as two overlapping 32-bit ints, apply different
+			// hash functions to each in parallel, then xor the results
+			x := *(*uint32)(unsafe.Pointer(&b[n-4]))
+			y := *(*uint32)(unsafe.Pointer(&b[0]))
+			hx := hashInt(uint64(x), alg)
+			hy := hashInt(uint64(y), alg^1)
+			return uint64(n) ^ hx ^ hy
+		case n > 0:
+			x := uint32((n << 24) ^ (uint32(b[0]) << 16) ^ (uint32(b[n/2]) << 8) ^ uint32(b[n-1]))
+			return hashInt(uint64(x), alg)
+		case n == 0:
+			return 1
+		}
+	}
+
+	// increase differentiation enough to improve hash quality
+	return xxh3.Hash(b) + exprimes[alg]
+}
+
+const (
+	sentinel   uint64 = 0
+	loadFactor int64  = 2
+)
+
+func max(a, b uint64) uint64 {
+	if a > b {
+		return a
+	}
+	return b
+}
+
+var isNan32Cmp = func(v float32) bool { return math.IsNaN(float64(v)) }
+
+// KeyNotFound is the constant returned by memo table functions when a key isn't found in the table.
+const KeyNotFound = -1
+
+// BinaryMemoTable is our hashtable for binary data using the BinaryBuilder
+// to construct the actual data in an easy to pass around way with minimal copies
+// while using a hash table to keep track of the indexes into the dictionary that
+// is created as we go.
+type BinaryMemoTable struct {
+	tbl     *Int32HashTable
+	builder *array.BinaryBuilder
+	nullIdx int
+}
+
+// NewBinaryMemoTable returns a hash table for Binary data. The passed in allocator will
+// be utilized for the BinaryBuilder; if nil, memory.DefaultAllocator will be used.
+// initial and valuesize can be used to pre-allocate the table to reduce allocations:
+// initial is the number of entries to allocate for, and valuesize is the starting
+// amount of space allocated for writing the actual binary data.
+func NewBinaryMemoTable(mem memory.Allocator, initial, valuesize int) *BinaryMemoTable {
+	if mem == nil {
+		mem = memory.DefaultAllocator
+	}
+	bldr := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary)
+	bldr.Reserve(int(initial))
+	datasize := valuesize
+	if datasize <= 0 {
+		datasize = initial * 4
+	}
+	bldr.ReserveData(datasize)
+	return &BinaryMemoTable{tbl: NewInt32HashTable(uint64(initial)), builder: bldr, nullIdx: KeyNotFound}
+}
+
+// Reset dumps all of the data in the table allowing it to be reutilized.
+func (s *BinaryMemoTable) Reset() {
+	s.tbl.Reset(32)
+	s.builder.NewArray().Release()
+	s.builder.Reserve(int(32))
+	s.builder.ReserveData(int(32) * 4)
+	s.nullIdx = KeyNotFound
+}
+
+// GetNull returns the index of a null that has been inserted into the table or
+// KeyNotFound. The bool returned will be true if there was a null inserted into
+// the table, and false otherwise.
+func (s *BinaryMemoTable) GetNull() (int, bool) {
+	return int(s.nullIdx), s.nullIdx != KeyNotFound
+}
+
+// Size returns the current size of the memo table including the null value
+// if one has been inserted.
+func (s *BinaryMemoTable) Size() int {
+	sz := int(s.tbl.size)
+	if _, ok := s.GetNull(); ok {
+		sz++
+	}
+	return sz
+}
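
The small-key fast path above avoids a call into xxh3 by reading the head and the tail of the buffer as two overlapping words, so every byte influences the result without a loop. A safe-Go sketch of the 4-to-8-byte case (using `encoding/binary` instead of the unsafe pointer reads, and omitting the byte swap for brevity; `hashShort` is illustrative only):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// hashShort sketches the "case n >= 4" branch: the first and last four bytes
// are read as two overlapping 32-bit words so that all n bytes (4 <= n <= 8)
// contribute, each word is mixed with a different prime, and the length is
// folded in at the end.
func hashShort(b []byte) uint64 {
	n := len(b)
	x := uint64(binary.LittleEndian.Uint32(b[n-4:]))
	y := uint64(binary.LittleEndian.Uint32(b[:4]))
	return uint64(n) ^ (11400714785074694791 * x) ^ (14029467366897019727 * y)
}

func main() {
	fmt.Printf("%#x\n", hashShort([]byte("parquet"))) // 7 bytes: the two reads overlap
	fmt.Printf("%#x\n", hashShort([]byte("parq")))    // 4 bytes: both reads see all bytes
}
```

+// helper function to easily return a byte slice for any given value
+// regardless of whether it's a []byte, parquet.ByteArray,
+// parquet.FixedLenByteArray or string.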
+func (BinaryMemoTable) valAsByteSlice(val interface{}) []byte { + switch v := val.(type) { + case []byte: + return v + case parquet.ByteArray: + return *(*[]byte)(unsafe.Pointer(&v)) + case parquet.FixedLenByteArray: + return *(*[]byte)(unsafe.Pointer(&v)) + case string: + return (*(*[]byte)(unsafe.Pointer(&v)))[:len(v):len(v)] + default: + panic("invalid type for binarymemotable") + } +} + +// helper function to get the hash value regardless of the underlying binary type +func (BinaryMemoTable) getHash(val interface{}) uint64 { + switch v := val.(type) { + case string: + return hashString(v, 0) + case []byte: + return hash(v, 0) + case parquet.ByteArray: + return hash(*(*[]byte)(unsafe.Pointer(&v)), 0) + case parquet.FixedLenByteArray: + return hash(*(*[]byte)(unsafe.Pointer(&v)), 0) + default: + panic("invalid type for binarymemotable") + } +} + +// helper function to append the given value to the builder regardless +// of the underlying binary type. +func (b *BinaryMemoTable) appendVal(val interface{}) { + switch v := val.(type) { + case string: + b.builder.AppendString(v) + case []byte: + b.builder.Append(v) + case parquet.ByteArray: + b.builder.Append(*(*[]byte)(unsafe.Pointer(&v))) + case parquet.FixedLenByteArray: + b.builder.Append(*(*[]byte)(unsafe.Pointer(&v))) + } +} + +func (b *BinaryMemoTable) lookup(h uint64, val []byte) (*entryInt32, bool) { + return b.tbl.Lookup(h, func(i int32) bool { + return bytes.Equal(val, b.builder.Value(int(i))) + }) +} + +// Get returns the index of the specified value in the table or KeyNotFound, +// and a boolean indicating whether it was found in the table. +func (b *BinaryMemoTable) Get(val interface{}) (int, bool) { + if p, ok := b.lookup(b.getHash(val), b.valAsByteSlice(val)); ok { + return int(p.payload.val), ok + } + return KeyNotFound, false +} + +// GetOrInsert returns the index of the given value in the table, if not found +// it is inserted into the table. The return value 'found' indicates whether the value +// was found in the table (true) or inserted (false) along with any possible error. +func (b *BinaryMemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + h := b.getHash(val) + p, found := b.lookup(h, b.valAsByteSlice(val)) + if found { + idx = int(p.payload.val) + } else { + idx = b.Size() + b.appendVal(val) + b.tbl.Insert(p, h, int32(idx), -1) + } + return +} + +// GetOrInsertNull retrieves the index of a null in the table or inserts +// null into the table, returning the index and a boolean indicating if it was +// found in the table (true) or was inserted (false). +func (b *BinaryMemoTable) GetOrInsertNull() (idx int, found bool) { + idx, found = b.GetNull() + if !found { + idx = b.Size() + b.nullIdx = idx + b.builder.AppendNull() + } + return +} + +// helper function to get the offset into the builder data for a given +// index value. +func (b *BinaryMemoTable) findOffset(idx int) uintptr { + val := b.builder.Value(idx) + for len(val) == 0 { + idx++ + if idx >= b.builder.Len() { + break + } + val = b.builder.Value(idx) + } + if len(val) != 0 { + return uintptr(unsafe.Pointer(&val[0])) + } + return uintptr(b.builder.DataLen()) + b.findOffset(0) +} + +// CopyOffsets copies the list of offsets into the passed in slice, the offsets +// being the start and end values of the underlying allocated bytes in the builder +// for the individual values of the table. 
out should be at least sized to Size()+1
+func (b *BinaryMemoTable) CopyOffsets(out []int32) {
+	b.CopyOffsetsSubset(0, out)
+}
+
+// CopyOffsetsSubset is like CopyOffsets but instead of copying all of the offsets,
+// it gets a subset of the offsets in the table starting at the index provided by "start".
+func (b *BinaryMemoTable) CopyOffsetsSubset(start int, out []int32) {
+	if b.builder.Len() <= start {
+		return
+	}
+
+	first := b.findOffset(0)
+	delta := b.findOffset(start)
+	for i := start; i < b.Size(); i++ {
+		offset := int32(b.findOffset(i) - delta)
+		out[i-start] = offset
+	}
+
+	out[b.Size()-start] = int32(b.builder.DataLen() - (int(delta) - int(first)))
+}
+
+// CopyValues copies the raw binary data bytes out; out should be a []byte
+// with at least ValuesSize bytes allocated to copy into.
+func (b *BinaryMemoTable) CopyValues(out interface{}) {
+	b.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies the raw binary data bytes out starting with the value
+// at the index start; out should be a []byte with at least ValuesSize bytes allocated.
+func (b *BinaryMemoTable) CopyValuesSubset(start int, out interface{}) {
+	var (
+		first  = b.findOffset(0)
+		offset = b.findOffset(int(start))
+		length = b.builder.DataLen() - int(offset-first)
+	)
+
+	outval := out.([]byte)
+	copy(outval, b.builder.Value(start)[0:length])
+}
+
+// CopyFixedWidthValues exists to cope with the fact that the table doesn't keep
+// track of the fixed width; when inserting the null value, the data buffer holds a
+// zero-length byte slice for the null value (if found).
+func (b *BinaryMemoTable) CopyFixedWidthValues(start, width int, out []byte) {
+	if start >= b.Size() {
+		return
+	}
+
+	null, exists := b.GetNull()
+	if !exists || null < start {
+		// nothing to skip, proceed as usual
+		b.CopyValuesSubset(start, out)
+		return
+	}
+
+	var (
+		leftOffset = b.findOffset(start)
+		nullOffset = b.findOffset(null)
+		leftSize   = nullOffset - leftOffset
+	)
+
+	if leftSize > 0 {
+		copy(out, b.builder.Value(start)[0:leftSize])
+	}
+
+	rightSize := b.ValuesSize() - int(nullOffset)
+	if rightSize > 0 {
+		// skip the null fixed size value
+		copy(out[int(leftSize)+width:], b.builder.Value(int(nullOffset))[0:rightSize])
+	}
+}
+
+// VisitValues exists to run the visitFn on each value currently in the hash table.
+func (b *BinaryMemoTable) VisitValues(start int, visitFn func([]byte)) {
+	for i := int(start); i < b.Size(); i++ {
+		visitFn(b.builder.Value(i))
+	}
+}
+
+// Release is used to tell the underlying builder that it can release the memory allocated
+// when the reference count reaches 0. This is safe to call from multiple goroutines
+// simultaneously.
+func (b *BinaryMemoTable) Release() { b.builder.Release() }
+
+// Retain increases the ref count; it is safe to call it from multiple goroutines
+// simultaneously.
+func (b *BinaryMemoTable) Retain() { b.builder.Retain() }
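
Putting the pieces together, a hypothetical caller can rebuild the dictionary from the raw value bytes plus the offsets. This is illustrative only (the package is internal to the parquet module, and the offsets follow the int32 signature above):

```go
package main

import (
	"fmt"

	"github.com/apache/arrow/go/parquet/internal/hashing"
)

func main() {
	memo := hashing.NewBinaryMemoTable(nil, 16, 64)
	defer memo.Release()

	// Distinct values get indexes in first-seen order; repeats are found.
	for _, s := range []string{"foo", "bar", "foo", "baz"} {
		idx, found, _ := memo.GetOrInsert(s)
		fmt.Println(s, idx, found) // foo 0 false; bar 1 false; foo 0 true; baz 2 false
	}

	// Reassemble the dictionary from the raw bytes plus Size()+1 offsets.
	values := make([]byte, memo.ValuesSize())
	memo.CopyValues(values)
	offsets := make([]int32, memo.Size()+1)
	memo.CopyOffsets(offsets)
	for i := 0; i < memo.Size(); i++ {
		fmt.Println(string(values[offsets[i]:offsets[i+1]])) // foo, bar, baz
	}
}
```

+// ValuesSize returns the current total size of all the raw bytes that have been inserted
+// into the memotable so far.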
+func (b *BinaryMemoTable) ValuesSize() int { return b.builder.DataLen() } diff --git a/go/parquet/internal/utils/bit_reader_test.go b/go/parquet/internal/utils/bit_reader_test.go index 4b246e13fc228..b076a1a30d5f8 100644 --- a/go/parquet/internal/utils/bit_reader_test.go +++ b/go/parquet/internal/utils/bit_reader_test.go @@ -175,18 +175,36 @@ func TestMixedValues(t *testing.T) { } func TestZigZag(t *testing.T) { - testvals := []int64{0, 1, 1234, -1, -1234, math.MaxInt32, -math.MaxInt32} + testvals := []struct { + val int64 + exp [10]byte + }{ + {0, [...]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + {1, [...]byte{2, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + {1234, [...]byte{164, 19, 0, 0, 0, 0, 0, 0, 0, 0}}, + {-1, [...]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + {-1234, [...]byte{163, 19, 0, 0, 0, 0, 0, 0, 0, 0}}, + {math.MaxInt32, [...]byte{254, 255, 255, 255, 15, 0, 0, 0, 0, 0}}, + {-math.MaxInt32, [...]byte{253, 255, 255, 255, 15, 0, 0, 0, 0, 0}}, + {math.MinInt32, [...]byte{255, 255, 255, 255, 15, 0, 0, 0, 0, 0}}, + {math.MaxInt64, [...]byte{254, 255, 255, 255, 255, 255, 255, 255, 255, 1}}, + {-math.MaxInt64, [...]byte{253, 255, 255, 255, 255, 255, 255, 255, 255, 1}}, + {math.MinInt64, [...]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 1}}, + } + for _, v := range testvals { - t.Run(strconv.Itoa(int(v)), func(t *testing.T) { + t.Run(strconv.Itoa(int(v.val)), func(t *testing.T) { var buf [binary.MaxVarintLen64]byte wrtr := utils.NewBitWriter(utils.NewWriterAtBuffer(buf[:])) - assert.True(t, wrtr.WriteZigZagVlqInt(v)) + assert.True(t, wrtr.WriteZigZagVlqInt(v.val)) wrtr.Flush(false) + assert.Equal(t, v.exp, buf) + rdr := utils.NewBitReader(bytes.NewReader(buf[:])) val, ok := rdr.GetZigZagVlqInt() assert.True(t, ok) - assert.EqualValues(t, v, val) + assert.EqualValues(t, v.val, val) }) } } @@ -221,7 +239,7 @@ func (r *RLETestSuite) ValidateRle(vals []uint64, width int, expected []byte, ex enc := utils.NewRleEncoder(utils.NewWriterAtBuffer(buf), width) for _, val := range vals { - r.True(enc.Put(val)) + r.NoError(enc.Put(val)) } encoded := enc.Flush() if explen != -1 { @@ -450,7 +468,7 @@ func (r *RLERandomSuite) checkRoundTrip(vals []uint64, width int) bool { res := r.Run("encode values", func() { enc := utils.NewRleEncoder(utils.NewWriterAtBuffer(buf), width) for idx, val := range vals { - r.Require().Truef(enc.Put(val), "encoding idx: %d", idx) + r.Require().NoErrorf(enc.Put(val), "encoding idx: %d", idx) } encoded = enc.Flush() }) @@ -485,7 +503,7 @@ func (r *RLERandomSuite) checkRoundTripSpaced(vals array.Interface, width int) { case *array.Int32: for i := 0; i < v.Len(); i++ { if v.IsValid(i) { - r.Require().True(encoder.Put(uint64(v.Value(i)))) + r.Require().NoError(encoder.Put(uint64(v.Value(i)))) } } } diff --git a/go/parquet/internal/utils/rle.go b/go/parquet/internal/utils/rle.go index d31dd1d13714a..72dbc36767949 100644 --- a/go/parquet/internal/utils/rle.go +++ b/go/parquet/internal/utils/rle.go @@ -492,14 +492,14 @@ func (r *RleEncoder) Flush() int { return r.w.Written() } -func (r *RleEncoder) flushBuffered(done bool) { +func (r *RleEncoder) flushBuffered(done bool) (err error) { if r.repCount >= 8 { // clear buffered values. they are part of the repeated run now and we // don't want to flush them as literals r.buffer = r.buffer[:0] if r.litCount != 0 { // there was current literal run. 
all values flushed but need to update the indicator - r.flushLiteral(true) + err = r.flushLiteral(true) } return } @@ -509,20 +509,23 @@ func (r *RleEncoder) flushBuffered(done bool) { if ngroups+1 >= (1 << 6) { // we need to start a new literal run because the indicator byte we've reserved // cannot store any more values - r.flushLiteral(true) + err = r.flushLiteral(true) } else { - r.flushLiteral(done) + err = r.flushLiteral(done) } r.repCount = 0 + return } -func (r *RleEncoder) flushLiteral(updateIndicator bool) { +func (r *RleEncoder) flushLiteral(updateIndicator bool) (err error) { if r.literalIndicatorOffset == -1 { r.literalIndicatorOffset = r.w.ReserveBytes(1) } for _, val := range r.buffer { - r.w.WriteValue(val, uint(r.BitWidth)) + if err = r.w.WriteValue(val, uint(r.BitWidth)); err != nil { + return + } } r.buffer = r.buffer[:0] @@ -532,35 +535,39 @@ func (r *RleEncoder) flushLiteral(updateIndicator bool) { // the logic makes sure we flush literal runs often enough to not overrun the 1 byte. ngroups := r.litCount / 8 r.indicatorBuffer[0] = byte((ngroups << 1) | 1) - r.w.WriteAt(r.indicatorBuffer[:], int64(r.literalIndicatorOffset)) + _, err = r.w.WriteAt(r.indicatorBuffer[:], int64(r.literalIndicatorOffset)) r.literalIndicatorOffset = -1 r.litCount = 0 } + return } -func (r *RleEncoder) flushRepeated() { +func (r *RleEncoder) flushRepeated() (ret bool) { indicator := r.repCount << 1 - r.w.WriteVlqInt(uint64(indicator)) - r.w.WriteAligned(r.curVal, int(bitutil.BytesForBits(int64(r.BitWidth)))) + + ret = r.w.WriteVlqInt(uint64(indicator)) + ret = ret && r.w.WriteAligned(r.curVal, int(bitutil.BytesForBits(int64(r.BitWidth)))) r.repCount = 0 r.buffer = r.buffer[:0] + return } // Put buffers input values 8 at a time. after seeing all 8 values, // it decides whether they should be encoded as a literal or repeated run. -func (r *RleEncoder) Put(value uint64) bool { - +func (r *RleEncoder) Put(value uint64) error { if r.curVal == value { r.repCount++ if r.repCount > 8 { // this is just a continuation of the current run, no need to buffer the values // NOTE this is the fast path for long repeated runs - return true + return nil } } else { if r.repCount >= 8 { - r.flushRepeated() + if !r.flushRepeated() { + return xerrors.New("failed to flush repeated value") + } } r.repCount = 1 r.curVal = value @@ -568,9 +575,9 @@ func (r *RleEncoder) Put(value uint64) bool { r.buffer = append(r.buffer, value) if len(r.buffer) == 8 { - r.flushBuffered(false) + return r.flushBuffered(false) } - return true + return nil } func (r *RleEncoder) Clear() {
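
With this change, `Put` reports write failures as an error instead of a bare bool, so encoding loops propagate failures the same way the updated tests do with `NoError`. A sketch of the caller-side pattern (the `encodeRLE` helper and its fixed-size scratch buffer are illustrative, and the utils package is internal to the parquet module):

```go
package main

import (
	"fmt"

	"github.com/apache/arrow/go/parquet/internal/utils"
)

// encodeRLE shows the migration: callers previously asserted Put returned
// true; now they propagate the error it returns.
func encodeRLE(vals []uint64, width int) ([]byte, error) {
	buf := make([]byte, 1024) // illustrative fixed-size scratch buffer
	enc := utils.NewRleEncoder(utils.NewWriterAtBuffer(buf), width)
	for _, v := range vals {
		if err := enc.Put(v); err != nil {
			return nil, err
		}
	}
	n := enc.Flush() // Flush returns the number of bytes written
	return buf[:n], nil
}

func main() {
	encoded, err := encodeRLE([]uint64{1, 1, 1, 1, 1, 1, 1, 1, 2, 2}, 3)
	if err != nil {
		panic(err)
	}
	fmt.Println(encoded)
}
```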