Skip to content

Commit

Permalink
aghio: create package
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Nov 19, 2020
1 parent 8a9c6e8 commit 7b77c2c
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
14 changes: 14 additions & 0 deletions internal/aghio/limitedbody.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package aghio

import (
"net/http"
)

// LimitRequestBody substitutes body of the request with LimitedReadCloser.
func LimitRequestBody(h http.Handler) (limited http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = LimitReadCloser(r.Body, 1024)

h.ServeHTTP(w, r)
})
}
48 changes: 48 additions & 0 deletions internal/aghio/limitedreader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Package aghio contains extensions for io package's types and methods
package aghio

import (
"fmt"
"io"

"github.com/AdguardTeam/AdGuardHome/internal/agherr"
)

// ErrLimitReached is returned if the limit of LimitedReader is reached.
const ErrLimitReached agherr.Error = "limit reached"

// limitedReadCloser is a wrapper for io.ReadCloser with limited reader and
// dealing with agherr package.
type limitedReadCloser struct {
limit int64
N int64
io.ReadCloser
}

// Read implements Reader interface.
func (lrc *limitedReadCloser) Read(p []byte) (n int, err error) {
if lrc.N <= 0 {
return 0, fmt.Errorf("read %d bytes: %w", lrc.limit, ErrLimitReached)
}
if int64(len(p)) > lrc.N {
p = p[0:lrc.N]
}
n, err = lrc.ReadCloser.Read(p)
lrc.N -= int64(n)
return n, err
}

// Close implements Closer interface.
func (lrc *limitedReadCloser) Close() error {
return lrc.ReadCloser.Close()
}

// LimitReadCloser returns a ReadCloser with original Closer and Reader that
// stops with ErrLimitReached after n bytes read.
func LimitReadCloser(rc io.ReadCloser, n int64) io.ReadCloser {
return &limitedReadCloser{
limit: n,
N: n,
ReadCloser: rc,
}
}
58 changes: 58 additions & 0 deletions internal/aghio/limitedreader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package aghio

import (
"errors"
"io"
"io/ioutil"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestLimitedReadCloser_Read(t *testing.T) {
testCases := []struct {
name string
limit int64
rStr string
want int
err error
}{{
name: "perfectly_match",
limit: 3,
rStr: "abc",
want: 3,
err: nil,
}, {
name: "eof",
limit: 3,
rStr: "",
want: 0,
err: io.EOF,
}, {
name: "limit_reached",
limit: 0,
rStr: "abc",
want: 0,
err: ErrLimitReached,
}, {
name: "truncated",
limit: 2,
rStr: "abc",
want: 2,
err: nil,
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
readCloser := ioutil.NopCloser(strings.NewReader(tc.rStr))
buf := make([]byte, tc.limit+1)

lreader := LimitReadCloser(readCloser, tc.limit)
n, err := lreader.Read(buf)

assert.Equal(t, n, tc.want)
assert.True(t, errors.Is(err, tc.err), buf)
})
}
}

0 comments on commit 7b77c2c

Please sign in to comment.