Skip to content

Commit

Permalink
add safe cast for uint
Browse files Browse the repository at this point in the history
  • Loading branch information
moniliu committed Jun 2, 2023
1 parent dbb6117 commit 78f39d6
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 6 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect
golang.org/x/mod v0.8.0 // indirect
golang.org/x/sys v0.5.0 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
Expand Down
21 changes: 15 additions & 6 deletions graphql/uint.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"io"
"strconv"

"golang.org/x/exp/constraints"
)

func MarshalUint(i uint) Marshaler {
Expand All @@ -19,9 +21,9 @@ func UnmarshalUint(v interface{}) (uint, error) {
u64, err := strconv.ParseUint(v, 10, 64)
return uint(u64), err
case int:
return uint(v), nil
return safeUintCast[int, uint](v)
case int64:
return uint(v), nil
return safeUintCast[int64, uint](v)
case json.Number:
u64, err := strconv.ParseUint(string(v), 10, 64)
return uint(u64), err
Expand All @@ -41,9 +43,9 @@ func UnmarshalUint64(v interface{}) (uint64, error) {
case string:
return strconv.ParseUint(v, 10, 64)
case int:
return uint64(v), nil
return safeUintCast[int, uint64](v)
case int64:
return uint64(v), nil
return safeUintCast[int64, uint64](v)
case json.Number:
return strconv.ParseUint(string(v), 10, 64)
default:
Expand All @@ -66,9 +68,9 @@ func UnmarshalUint32(v interface{}) (uint32, error) {
}
return uint32(iv), nil
case int:
return uint32(v), nil
return safeUintCast[int, uint32](v)
case int64:
return uint32(v), nil
return safeUintCast[int64, uint32](v)
case json.Number:
iv, err := strconv.ParseUint(string(v), 10, 32)
if err != nil {
Expand All @@ -79,3 +81,10 @@ func UnmarshalUint32(v interface{}) (uint32, error) {
return 0, fmt.Errorf("%T is not an uint", v)
}
}

func safeUintCast[F constraints.Signed, T constraints.Unsigned](f F) (T, error) {
if f < 0 {
return 0, fmt.Errorf("cannot cast %d to uint", f)
}
return T(f), nil
}
36 changes: 36 additions & 0 deletions graphql/uint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ func TestUint(t *testing.T) {
assert.Equal(t, uint(123), mustUnmarshalUint(int64(123)))
assert.Equal(t, uint(123), mustUnmarshalUint(json.Number("123")))
assert.Equal(t, uint(123), mustUnmarshalUint("123"))
assert.NotNil(t, mustFailUnmarshalUint(-2))
assert.NotNil(t, mustFailUnmarshalUint(int64(-123)))
assert.NotNil(t, mustFailUnmarshalUint("-4294967295"))
assert.NotNil(t, mustFailUnmarshalUint(json.Number("-123")))
})
}

Expand All @@ -29,6 +33,14 @@ func mustUnmarshalUint(v interface{}) uint {
return res
}

func mustFailUnmarshalUint(v interface{}) error {
_, err := UnmarshalUint(v)
if err == nil {
panic(err)
}
return err
}

func TestUint32(t *testing.T) {
t.Run("marshal", func(t *testing.T) {
assert.Equal(t, "123", m2s(MarshalUint32(123)))
Expand All @@ -41,6 +53,10 @@ func TestUint32(t *testing.T) {
assert.Equal(t, uint32(123), mustUnmarshalUint32(json.Number("123")))
assert.Equal(t, uint32(123), mustUnmarshalUint32("123"))
assert.Equal(t, uint32(4294967295), mustUnmarshalUint32("4294967295"))
assert.NotNil(t, mustFailUnmarshalUint32(-2))
assert.NotNil(t, mustFailUnmarshalUint32(int64(-123)))
assert.NotNil(t, mustFailUnmarshalUint32("-4294967295"))
assert.NotNil(t, mustFailUnmarshalUint32(json.Number("-123")))
})
}

Expand All @@ -52,6 +68,14 @@ func mustUnmarshalUint32(v interface{}) uint32 {
return res
}

func mustFailUnmarshalUint32(v interface{}) error {
_, err := UnmarshalUint32(v)
if err == nil {
panic(err)
}
return err
}

func TestUint64(t *testing.T) {
t.Run("marshal", func(t *testing.T) {
assert.Equal(t, "123", m2s(MarshalUint64(123)))
Expand All @@ -62,6 +86,10 @@ func TestUint64(t *testing.T) {
assert.Equal(t, uint64(123), mustUnmarshalUint64(int64(123)))
assert.Equal(t, uint64(123), mustUnmarshalUint64(json.Number("123")))
assert.Equal(t, uint64(123), mustUnmarshalUint64("123"))
assert.NotNil(t, mustFailUnmarshalUint64(-2))
assert.NotNil(t, mustFailUnmarshalUint64(int64(-123)))
assert.NotNil(t, mustFailUnmarshalUint64("-4294967295"))
assert.NotNil(t, mustFailUnmarshalUint64(json.Number("-123")))
})
}

Expand All @@ -72,3 +100,11 @@ func mustUnmarshalUint64(v interface{}) uint64 {
}
return res
}

func mustFailUnmarshalUint64(v interface{}) error {
_, err := UnmarshalUint64(v)
if err == nil {
panic(err)
}
return err
}

0 comments on commit 78f39d6

Please sign in to comment.