diff --git a/docs/content/reference/scalars.md b/docs/content/reference/scalars.md index f35a822527..8241090b6a 100644 --- a/docs/content/reference/scalars.md +++ b/docs/content/reference/scalars.md @@ -35,7 +35,7 @@ Maps a `Upload` GraphQL scalar to a `graphql.Upload` struct, defined as follows: ```go type Upload struct { - File io.Reader + File io.ReadSeeker Filename string Size int64 ContentType string diff --git a/graphql/handler/transport/http_form.go b/graphql/handler/transport/http_form.go index 50eef20030..3d3477b9ba 100644 --- a/graphql/handler/transport/http_form.go +++ b/graphql/handler/transport/http_form.go @@ -131,7 +131,7 @@ func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.G } for _, path := range paths { upload = graphql.Upload{ - File: &bytesReader{s: &fileBytes, i: 0, prevRune: -1}, + File: &bytesReader{s: &fileBytes, i: 0}, Size: int64(len(fileBytes)), Filename: filename, ContentType: contentType, diff --git a/graphql/handler/transport/reader.go b/graphql/handler/transport/reader.go index d3261e2833..c58eccf33f 100644 --- a/graphql/handler/transport/reader.go +++ b/graphql/handler/transport/reader.go @@ -6,9 +6,8 @@ import ( ) type bytesReader struct { - s *[]byte - i int64 // current reading index - prevRune int // index of previous rune; or < 0 + s *[]byte + i int64 // current reading index } func (r *bytesReader) Read(b []byte) (n int, err error) { @@ -18,8 +17,29 @@ func (r *bytesReader) Read(b []byte) (n int, err error) { if r.i >= int64(len(*r.s)) { return 0, io.EOF } - r.prevRune = -1 n = copy(b, (*r.s)[r.i:]) r.i += int64(n) return } + +func (r *bytesReader) Seek(offset int64, whence int) (int64, error) { + if r.s == nil { + return 0, errors.New("byte slice pointer is nil") + } + var abs int64 + switch whence { + case io.SeekStart: + abs = offset + case io.SeekCurrent: + abs = r.i + offset + case io.SeekEnd: + abs = int64(len(*r.s)) + offset + default: + return 0, errors.New("invalid whence") + } + if abs < 0 { + return 0, errors.New("negative position") + } + r.i = abs + return abs, nil +} diff --git a/graphql/handler/transport/reader_test.go b/graphql/handler/transport/reader_test.go index eaff3b2b81..e57ac5a14f 100644 --- a/graphql/handler/transport/reader_test.go +++ b/graphql/handler/transport/reader_test.go @@ -82,4 +82,46 @@ func TestBytesRead(t *testing.T) { } require.Equal(t, "0193456789", string(got)) }) + + t.Run("read using buffer multiple times", func(t *testing.T) { + data := []byte("0123456789") + r := bytesReader{s: &data} + + got := make([]byte, 0, 11) + buf := make([]byte, 1) + for { + n, err := r.Read(buf) + if n < 0 { + require.Fail(t, "unexpected bytes read size") + } + got = append(got, buf[:n]...) + if err != nil { + if err == io.EOF { + break + } + require.Fail(t, "unexpected error while reading", err.Error()) + } + } + require.Equal(t, "0123456789", string(got)) + + pos, err := r.Seek(0, io.SeekStart) + require.NoError(t, err) + require.Equal(t, int64(0), pos) + + got = make([]byte, 0, 11) + for { + n, err := r.Read(buf) + if n < 0 { + require.Fail(t, "unexpected bytes read size") + } + got = append(got, buf[:n]...) + if err != nil { + if err == io.EOF { + break + } + require.Fail(t, "unexpected error while reading", err.Error()) + } + } + require.Equal(t, "0123456789", string(got)) + }) } diff --git a/graphql/upload.go b/graphql/upload.go index 62f71c0dc0..dafbde6508 100644 --- a/graphql/upload.go +++ b/graphql/upload.go @@ -6,7 +6,7 @@ import ( ) type Upload struct { - File io.Reader + File io.ReadSeeker Filename string Size int64 ContentType string