/
iter.go
executable file
·98 lines (90 loc) · 2.76 KB
/
iter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqlshim
import (
dsql "database/sql"
"io"
"reflect"
"time"
"github.com/Rock-liyi/p2pdb-store/sql"
)
// mysqlIter wraps an iterator returned by the MySQL connection.
// It adapts a database/sql result set (*dsql.Rows) to the sql.RowIter
// interface.
type mysqlIter struct {
	// rows is the wrapped result set that Next advances and Close closes.
	rows *dsql.Rows
	// types holds one entry per result column, in column order: the Go type
	// that column is scanned into, as chosen by newMySQLIter.
	types []reflect.Type
}

// Compile-time check that mysqlIter satisfies sql.RowIter.
var _ sql.RowIter = mysqlIter{}
// newMySQLIter returns a new mysqlIter over rows.
//
// The Go type each result column is scanned into is resolved once, up front,
// from the driver-reported scan type (see mysqlScanType). newMySQLIter panics
// if the column types cannot be read from rows.
func newMySQLIter(rows *dsql.Rows) mysqlIter {
	colTypes, err := rows.ColumnTypes()
	if err != nil {
		panic(err)
	}
	scanTypes := make([]reflect.Type, 0, len(colTypes))
	for _, colType := range colTypes {
		scanTypes = append(scanTypes, mysqlScanType(colType.ScanType()))
	}
	return mysqlIter{rows: rows, types: scanTypes}
}

// mysqlScanType maps a driver-reported scan type onto the type a column is
// actually scanned into. sql.RawBytes becomes string, and the database/sql
// Null* wrappers become their underlying value types. Any other type is
// returned unchanged.
func mysqlScanType(t reflect.Type) reflect.Type {
	switch t {
	case reflect.TypeOf(dsql.RawBytes{}), reflect.TypeOf(dsql.NullString{}):
		return reflect.TypeOf("")
	case reflect.TypeOf(dsql.NullBool{}):
		return reflect.TypeOf(true)
	case reflect.TypeOf(dsql.NullFloat64{}):
		return reflect.TypeOf(float64(0))
	case reflect.TypeOf(dsql.NullInt32{}):
		return reflect.TypeOf(int32(0))
	case reflect.TypeOf(dsql.NullInt64{}):
		return reflect.TypeOf(int64(0))
	case reflect.TypeOf(dsql.NullTime{}):
		return reflect.TypeOf(time.Time{})
	// dsql.NullByte (-> byte) and dsql.NullInt16 (-> int16) would belong here
	// too, but they are not supported in Go 1.15 and require Go 1.17.
	default:
		return t
	}
}
// Next implements the interface sql.RowIter.
//
// Column i is scanned into a freshly allocated **T, where T is m.types[i].
// For a pointer-to-pointer destination, database/sql represents SQL NULL by
// leaving the inner *T nil. A plain *T destination of a value type (string,
// int64, time.Time, ...) instead makes Scan return an error on NULL. NULL
// columns are returned as nil, and []byte values are converted to string.
//
// When no row remains, Next returns the error that ended iteration (as
// reported by rows.Err) if there was one, and io.EOF otherwise.
func (m mysqlIter) Next(ctx *sql.Context) (sql.Row, error) {
	if !m.rows.Next() {
		// rows.Next also returns false when iteration failed; don't report
		// such a failure as a normal end of results.
		if err := m.rows.Err(); err != nil {
			return nil, err
		}
		return nil, io.EOF
	}

	output := make(sql.Row, len(m.types))
	for i, typ := range m.types {
		// PtrTo (rather than PointerTo) keeps this buildable on older Go.
		output[i] = reflect.New(reflect.PtrTo(typ)).Interface()
	}
	if err := m.rows.Scan(output...); err != nil {
		return nil, err
	}

	for i, dest := range output {
		// dest is a **T; ptr is the *T that Scan filled in (nil for NULL).
		ptr := reflect.ValueOf(dest).Elem()
		if ptr.IsNil() {
			output[i] = nil
			continue
		}
		val := ptr.Elem().Interface()
		if byteSlice, ok := val.([]byte); ok {
			val = string(byteSlice)
		}
		output[i] = val
	}
	return output, nil
}
// Close implements the interface sql.RowIter by closing the wrapped rows and
// returning whatever error that produces. The context is not used.
func (m mysqlIter) Close(ctx *sql.Context) error {
	err := m.rows.Close()
	return err
}