diff --git a/gonull.go b/gonull.go index e8dd993..9ea9501 100644 --- a/gonull.go +++ b/gonull.go @@ -100,21 +100,25 @@ func zeroValue[T any]() T { // convertToType is a helper function that attempts to convert the given value to type T. // This function is used by Scan to properly handle value conversion, ensuring that Nullable values are always of the correct type. func convertToType[T any](value interface{}) (T, error) { - switch v := value.(type) { - case T: - return v, nil - case int64: - // This case handles the situation when the input value is of type int64. - // It attempts to convert the int64 value to the target numeric type T if possible. - // If the conversion is successful, it returns the converted value of type T and a nil error. - // If the conversion is not possible, the function will continue to the next case (return an error). - switch t := reflect.Zero(reflect.TypeOf((*T)(nil)).Elem()).Interface().(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - if reflect.TypeOf(t).ConvertibleTo(reflect.TypeOf((*T)(nil)).Elem()) { - return reflect.ValueOf(value).Convert(reflect.TypeOf((*T)(nil)).Elem()).Interface().(T), nil - } + var zero T + if value == nil { + return zero, nil + } + + if reflect.TypeOf(value) == reflect.TypeOf(zero) { + return value.(T), nil + } + + // Check if the value is a numeric type and if T is also a numeric type. + valueType := reflect.TypeOf(value) + targetType := reflect.TypeOf(zero) + if valueType.Kind() >= reflect.Int && valueType.Kind() <= reflect.Float64 && + targetType.Kind() >= reflect.Int && targetType.Kind() <= reflect.Float64 { + if valueType.ConvertibleTo(targetType) { + convertedValue := reflect.ValueOf(value).Convert(targetType) + return convertedValue.Interface().(T), nil } } - var zero T + return zero, ErrUnsupportedConversion } diff --git a/gonull_test.go b/gonull_test.go index d4e3bed..3016033 100644 --- a/gonull_test.go +++ b/gonull_test.go @@ -1,17 +1,16 @@ -package gonull_test +package gonull import ( "database/sql/driver" "encoding/json" "testing" - "github.com/LukaGiorgadze/gonull" "github.com/stretchr/testify/assert" ) func TestNewNullable(t *testing.T) { value := "test" - n := gonull.NewNullable(value) + n := NewNullable(value) assert.True(t, n.Valid) assert.Equal(t, value, n.Val) @@ -52,7 +51,7 @@ func TestNullableScan(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var n gonull.Nullable[string] + var n Nullable[string] err := n.Scan(tt.value) if tt.wantErr { @@ -72,19 +71,19 @@ func TestNullableScan(t *testing.T) { func TestNullableValue(t *testing.T) { tests := []struct { name string - nullable gonull.Nullable[string] + nullable Nullable[string] wantValue driver.Value wantErr error }{ { name: "valid value", - nullable: gonull.NewNullable("test"), + nullable: NewNullable("test"), wantValue: "test", wantErr: nil, }, { name: "unset value", - nullable: gonull.Nullable[string]{Valid: false}, + nullable: Nullable[string]{Valid: false}, wantValue: nil, wantErr: nil, }, @@ -128,7 +127,7 @@ func TestNullableUnmarshalJSON(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var nullable gonull.Nullable[int] + var nullable Nullable[int] err := nullable.UnmarshalJSON(tc.jsonData) @@ -143,7 +142,7 @@ func TestNullableUnmarshalJSON(t *testing.T) { func TestNullableUnmarshalJSON_Error(t *testing.T) { jsonData := []byte(`"invalid_number"`) - var nullable gonull.Nullable[int] + var nullable Nullable[int] err := nullable.UnmarshalJSON(jsonData) assert.Error(t, err) @@ -153,19 +152,19 @@ func TestNullableUnmarshalJSON_Error(t *testing.T) { func TestNullableMarshalJSON(t *testing.T) { type testCase struct { name string - nullable gonull.Nullable[int] + nullable Nullable[int] expectedJSON []byte } testCases := []testCase{ { name: "ValuePresent", - nullable: gonull.NewNullable[int](123), + nullable: NewNullable[int](123), expectedJSON: []byte(`123`), }, { name: "ValueNull", - nullable: gonull.Nullable[int]{Val: 0, Valid: false}, + nullable: Nullable[int]{Val: 0, Valid: false}, expectedJSON: []byte(`null`), }, } @@ -182,7 +181,7 @@ func TestNullableMarshalJSON(t *testing.T) { func TestNullableScan_UnconvertibleFromInt64(t *testing.T) { value := int64(123456789012345) - var n gonull.Nullable[string] + var n Nullable[string] err := n.Scan(value) assert.Error(t, err) @@ -205,7 +204,7 @@ func TestConvertToTypeFromInt64(t *testing.T) { {name: "Convert int64 to uint16", targetType: "uint16", value: int64(7), expectedError: nil}, {name: "Convert int64 to uint32", targetType: "uint32", value: int64(8), expectedError: nil}, // Add more tests as necessary - {name: "Convert int64 to string (expected to fail)", targetType: "string", value: int64(9), expectedError: gonull.ErrUnsupportedConversion}, + {name: "Convert int64 to string (expected to fail)", targetType: "string", value: int64(9), expectedError: ErrUnsupportedConversion}, } for _, tt := range tests { @@ -213,31 +212,31 @@ func TestConvertToTypeFromInt64(t *testing.T) { var err error switch tt.targetType { case "int": - n := gonull.Nullable[int]{} + n := Nullable[int]{} err = n.Scan(tt.value) case "int8": - n := gonull.Nullable[int8]{} + n := Nullable[int8]{} err = n.Scan(tt.value) case "int16": - n := gonull.Nullable[int16]{} + n := Nullable[int16]{} err = n.Scan(tt.value) case "int32": - n := gonull.Nullable[int32]{} + n := Nullable[int32]{} err = n.Scan(tt.value) case "uint": - n := gonull.Nullable[uint]{} + n := Nullable[uint]{} err = n.Scan(tt.value) case "uint8": - n := gonull.Nullable[uint8]{} + n := Nullable[uint8]{} err = n.Scan(tt.value) case "uint16": - n := gonull.Nullable[uint16]{} + n := Nullable[uint16]{} err = n.Scan(tt.value) case "uint32": - n := gonull.Nullable[uint32]{} + n := Nullable[uint32]{} err = n.Scan(tt.value) case "string": - n := gonull.Nullable[string]{} + n := Nullable[string]{} err = n.Scan(tt.value) default: t.Fatalf("Unsupported type: %s", tt.targetType) @@ -263,7 +262,7 @@ func TestNullableScanWithCustomEnum(t *testing.T) { type TestModel struct { ID int - Field gonull.Nullable[TestEnum] + Field Nullable[TestEnum] } // Simulate the scenario where the SQL driver returns an int64 @@ -273,19 +272,121 @@ func TestNullableScanWithCustomEnum(t *testing.T) { // The converted value 0 (as float32) matches TestEnumA, which is also 0 when converted to float32. sqlReturnedValue := int64(0) - model := TestModel{ID: 1, Field: gonull.NewNullable(TestEnumA)} + model := TestModel{ID: 1, Field: NewNullable(TestEnumA)} err := model.Field.Scan(sqlReturnedValue) - if err != nil { - assert.Error(t, err, "Scan failed with unsupported type conversion") - } else { - assert.Equal(t, TestEnumA, model.Field.Val, "Scanned value does not match expected enum value") + assert.NoError(t, err, "Scan failed with unsupported type conversion") + assert.Equal(t, TestEnumA, model.Field.Val, "Scanned value does not match expected enum value") + +} + +func TestConvertToTypeWithNilValue(t *testing.T) { + tests := []struct { + name string + expected interface{} + }{ + { + name: "Nil to int", + expected: int(0), + }, + { + name: "Nil to int8", + expected: int8(0), + }, + { + name: "Nil to int16", + expected: int16(0), + }, + { + name: "Nil to int32", + expected: int32(0), + }, + { + name: "Nil to int64", + expected: int64(0), + }, + { + name: "Nil to uint", + expected: uint(0), + }, + { + name: "Nil to uint8 (byte)", + expected: uint8(0), + }, + { + name: "Nil to uint16", + expected: uint16(0), + }, + { + name: "Nil to uint32", + expected: uint32(0), + }, + { + name: "Nil to uint64", + expected: uint64(0), + }, + { + name: "Nil to float32", + expected: float32(0), + }, + { + name: "Nil to float64", + expected: float64(0), + }, + { + name: "Nil to bool", + expected: bool(false), + }, + { + name: "Nil to string", + expected: "", + }, } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var result interface{} + var err error + + switch tc.expected.(type) { + case int: + result, err = convertToType[int](nil) + case int8: + result, err = convertToType[int8](nil) + case int16: + result, err = convertToType[int16](nil) + case int32: + result, err = convertToType[int32](nil) + case int64: + result, err = convertToType[int64](nil) + case uint: + result, err = convertToType[uint](nil) + case uint8: + result, err = convertToType[uint8](nil) + case uint16: + result, err = convertToType[uint16](nil) + case uint32: + result, err = convertToType[uint32](nil) + case uint64: + result, err = convertToType[uint64](nil) + case float32: + result, err = convertToType[float32](nil) + case float64: + result, err = convertToType[float64](nil) + case bool: + result, err = convertToType[bool](nil) + case string: + result, err = convertToType[string](nil) + } + + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + }) + } } type testStruct struct { - Foo gonull.Nullable[*string] `json:"foo"` + Foo Nullable[*string] `json:"foo"` } func TestPresent(t *testing.T) {