Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/driver value conversion #19

Merged
merged 2 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion gonull.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"
)

Expand Down Expand Up @@ -68,7 +69,54 @@ func (n Nullable[T]) Value() (driver.Value, error) {
if valuer, ok := interface{}(n.Val).(driver.Valuer); ok {
return valuer.Value()
}
return n.Val, nil

return convertToDriverValue(n.Val)
}

func convertToDriverValue(v any) (driver.Value, error) {
if valuer, ok := v.(driver.Valuer); ok {
return valuer.Value()
}

rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Pointer:
if rv.IsNil() {
return nil, nil
}
return convertToDriverValue(rv.Elem().Interface())

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil

case reflect.Uint64:
u64 := rv.Uint()
if u64 >= 1<<63 {
return nil, fmt.Errorf("uint64 values with high bit set are not supported")
}
return int64(u64), nil

case reflect.Float32, reflect.Float64:
return rv.Float(), nil

case reflect.Bool:
return rv.Bool(), nil

case reflect.Slice:
if rv.Type().Elem().Kind() == reflect.Uint8 {
return rv.Bytes(), nil
}
return nil, fmt.Errorf("unsupported slice type: %s", rv.Type().Elem().Kind())

case reflect.String:
return rv.String(), nil

default:
return nil, fmt.Errorf("unsupported type: %T", v)
}
}

// UnmarshalJSON implements the json.Unmarshaler interface for Nullable, allowing it to be used as a nullable field in JSON operations.
Expand Down
97 changes: 97 additions & 0 deletions gonull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -495,3 +496,99 @@ func TestValuerAndScanner(t *testing.T) {
Val: testValuerScannerStruct{},
}, scannerNullableUnsupported)
}

type customValuer struct {
value any
err error
}

func (cv customValuer) Value() (driver.Value, error) {
return cv.value, cv.err
}

func TestConvertToDriverValue(t *testing.T) {
var (
intVal int = 123
int8Val int8 = 12
int16Val int16 = 1234
int32Val int32 = 12345
int64Val int64 = 123456
uintVal uint = 123
uint8Val uint8 = 12
uint16Val uint16 = 1234
uint32Val uint32 = 12345
uint64Val uint64 = 1 << 62
float32Val float32 = 12.34
float64Val float64 = 123.456
boolVal bool = true
stringVal string = "test"
byteSlice []byte = []byte("byte slice")
ptrToInt *int = &intVal
nilPtr *int = nil
valuerSuccess customValuer = customValuer{value: "valuer value", err: nil}
valuerError customValuer = customValuer{err: errors.New("valuer error")}
unsupportedSlice = []int{1, 2, 3}
)

tests := []struct {
name string
value any
want driver.Value
wantErr bool
}{
{"Int", intVal, int64(intVal), false},
{"Int8", int8Val, int64(int8Val), false},
{"Int16", int16Val, int64(int16Val), false},
{"Int32", int32Val, int64(int32Val), false},
{"Int64", int64Val, int64(int64Val), false},
{"Uint", uintVal, int64(uintVal), false},
{"Uint8", uint8Val, int64(uint8Val), false},
{"Uint16", uint16Val, int64(uint16Val), false},
{"Uint32", uint32Val, int64(uint32Val), false},
{"Uint64", uint64Val, int64(uint64Val), false},
{"Float32", float32Val, float64(float32Val), false},
{"Float64", float64Val, float64(float64Val), false},
{"Bool", boolVal, boolVal, false},
{"String", stringVal, stringVal, false},
{"ByteSlice", byteSlice, byteSlice, false},
{"PointerToInt", ptrToInt, int64(*ptrToInt), false},
{"NilPointer", nilPtr, nil, false},
{"UnsupportedType", struct{}{}, nil, true},
{"Uint64HighBitSet", uint64(1 << 63), nil, true}, // Uint64 with high bit set
{"ValuerInterfaceSuccess", valuerSuccess, "valuer value", false},
{"ValuerInterfaceError", valuerError, nil, true},
{"UnsupportedSliceType", unsupportedSlice, nil, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := convertToDriverValue(tt.value)
if (err != nil) != tt.wantErr {
t.Errorf("convertToDriverValue() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("convertToDriverValue() = %v, want %v", got, tt.want)
}
})
}
}

func TestNullableValue_Uint32(t *testing.T) {
uint32Val := uint32(12345)
nullableUint32 := NewNullable(uint32Val)

convertedValue, err := nullableUint32.Value()

if err != nil {
t.Fatalf("Nullable[uint32].Value() returned an error: %v", err)
}

if _, ok := convertedValue.(int64); !ok {
t.Fatalf("Nullable[uint32].Value() returned a non-int64 type: %T", convertedValue)
}

if int64(uint32Val) != convertedValue.(int64) {
t.Errorf("Nullable[uint32].Value() returned %v, want %v", convertedValue, uint32Val)
}
}