Skip to content

Commit

Permalink
Support scan of custom types implementing sql.Scanner into array/tupl…
Browse files Browse the repository at this point in the history
…e columns (#1129)

* fix(lib/column): support scan of custom types implementing sql.Scanner

Signed-off-by: Leonardo Di Donato <leodidonato@gmail.com>

* test: test for issue 1128

Signed-off-by: Leonardo Di Donato <leodidonato@gmail.com>

---------

Signed-off-by: Leonardo Di Donato <leodidonato@gmail.com>
  • Loading branch information
leodido committed Oct 26, 2023
1 parent 5f681e8 commit 25571c8
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 3 deletions.
17 changes: 14 additions & 3 deletions lib/column/tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
package column

import (
"database/sql"
"fmt"
"github.com/ClickHouse/ch-go/proto"
"github.com/google/uuid"
"github.com/shopspring/decimal"
"net"
"reflect"
"strings"
"time"

"github.com/ClickHouse/ch-go/proto"
"github.com/google/uuid"
"github.com/shopspring/decimal"
)

type Tuple struct {
Expand Down Expand Up @@ -209,6 +211,15 @@ func setJSONFieldValue(field reflect.Value, value reflect.Value) error {
return nil
}

// check if our target implements sql.Scanner
sqlScanner := reflect.TypeOf((*sql.Scanner)(nil)).Elem()
if fieldAddr := field.Addr(); field.Kind() != reflect.Ptr && fieldAddr.Type().Implements(sqlScanner) {
returns := fieldAddr.MethodByName("Scan").Call([]reflect.Value{value})
if len(returns) > 0 && returns[0].IsNil() {
return nil
}
}

return &ColumnConverterError{
Op: "ScanRow",
To: fmt.Sprintf("%T", field.Interface()),
Expand Down
88 changes: 88 additions & 0 deletions tests/issues/1128_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package issues

import (
"context"
"fmt"
"testing"

"github.com/ClickHouse/clickhouse-go/v2"
clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests"
"github.com/stretchr/testify/require"
)

const (
A SomeUint64AsString = iota + 1
B
C
)

type SomeUint64AsString uint64

func (f *SomeUint64AsString) Scan(src any) error {
if t, ok := src.(uint64); ok {
*f = SomeUint64AsString(t)
return nil
}
if t, ok := src.(string); ok {
switch t {
case "A":
*f = A
case "B":
*f = B
case "C":
*f = C
default:
return fmt.Errorf("cannot scan %s into SomeUint64AsString", t)
}
return nil
}

return fmt.Errorf("cannot scan %T into SomeUint64AsString", src)
}

func (f SomeUint64AsString) String() string {
switch f {
case A:
return "A"
case B:
return "B"
case C:
return "C"
}
return ""
}

type Test struct {
Col1 []SomeUint64AsString `ch:"Col1"`
}

func Test1128(t *testing.T) {
var (
conn, err = clickhouse_tests.GetConnection("issues", clickhouse.Settings{
"max_execution_time": 60,
"allow_experimental_object_type": true,
}, nil, &clickhouse.Compression{
Method: clickhouse.CompressionLZ4,
})
)
ctx := context.Background()
require.NoError(t, err)
const ddl = "CREATE TABLE test_1128 (Col1 Array(String)) Engine MergeTree() ORDER BY tuple()"
require.NoError(t, conn.Exec(ctx, ddl))
defer func() {
conn.Exec(ctx, "DROP TABLE IF EXISTS test_1128")
}()

batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_1128")
require.NoError(t, err)

data := Test{
Col1: []SomeUint64AsString{A, B, C},
}
require.NoError(t, batch.AppendStruct(&data))
require.NoError(t, batch.Send())

var res Test
require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_1128").ScanStruct(&res))
require.Equal(t, []SomeUint64AsString{A, B, C}, res.Col1)
}

0 comments on commit 25571c8

Please sign in to comment.