Skip to content

Commit

Permalink
rework timeout reader
Browse files Browse the repository at this point in the history
  • Loading branch information
andyollylarkin committed Aug 25, 2023
1 parent 3e204ee commit a583535
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 37 deletions.
3 changes: 2 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package watermillnet
import (
"errors"
"fmt"
"os"
)

type InvalidConfigError struct {
Expand All @@ -19,6 +20,6 @@ var (
ErrSubscriberClosed = errors.New("subscriber closed")
ErrSubscriberNotStarted = errors.New("subscriber not started")
ErrNacked = errors.New("remote side sent nack for message")
ErrIOTimeout = errors.New("i/o timeout")
ErrIOTimeout = os.ErrDeadlineExceeded
ErrConnectionNotSet = errors.New("connection not set")
)
2 changes: 1 addition & 1 deletion internal/reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,6 @@ func TestReaderErrorTimeoutButFullRead(t *testing.T) {
r := internal.NewTimeoutReader(&tr, time.Duration(0))

content, err := r.ReadBytes(delimiter)
require.ErrorIs(t, err, io.EOF)
require.NoError(t, err)
assert.Equal(t, "Hello", string(content[:len(content)-1]))
}
73 changes: 39 additions & 34 deletions internal/timeoutReader.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package internal

import (
"bufio"
"bytes"
"net"
"strings"
"errors"
"os"
"time"
)

type TimeoutReader struct {
bufioReader bufio.Reader
r ReadDeadliner
extendReadDeadline time.Duration
}
Expand All @@ -16,61 +18,64 @@ type TimeoutReader struct {
// If another error occurs reader return that error.
func NewTimeoutReader(r ReadDeadliner, extendReadDeadline time.Duration) *TimeoutReader {
tr := new(TimeoutReader)
tr.bufioReader = *bufio.NewReader(r)
tr.r = r
tr.extendReadDeadline = extendReadDeadline

return tr
}

func (tr *TimeoutReader) ReadBytes(delim byte) ([]byte, error) {
var out bytes.Buffer

var bufSize int = 1

buf := make([]byte, bufSize)
var outBuf bytes.Buffer

for {
n, err := tr.r.Read(buf)

// if has error and no read bytes
if err != nil && n == 0 {
return out.Bytes(), err
}
readed, err := tr.bufioReader.ReadBytes(LenDelimiter)
outBuf.Write(readed)

if err != nil && n > 0 {
e, ok := err.(*net.OpError)
if ok && strings.Contains(e.Error(), "i/o timeout") {
if err != nil {
if isTimeoutErr(err) && len(readed) > 0 {
tr.r.SetReadDeadline(time.Now().Add(tr.extendReadDeadline))
out.Write(buf)
cleanSlice(buf)

continue
} else {
out.Write(buf)

return out.Bytes(), e
return outBuf.Bytes(), err
}
} else {
break
}
}

pos := bytes.IndexByte(buf, delim)
return outBuf.Bytes(), nil
}

if pos == -1 {
out.Write(buf)
cleanSlice(buf)
func isTimeoutErr(err error) bool {
return errors.Is(err, os.ErrDeadlineExceeded)
}

continue
} else {
out.Write(buf[0 : pos+1])
func (tr *TimeoutReader) Read(p []byte) (int, error) {
var outBuf bytes.Buffer

for {
tmpBuf := make([]byte, len(p))
n, err := tr.bufioReader.Read(tmpBuf)
outBuf.Write(tmpBuf)

if err != nil {
if isTimeoutErr(err) && n > 0 {
tr.r.SetReadDeadline(time.Now().Add(tr.extendReadDeadline))

continue
} else {
copy(p, outBuf.Bytes())

return n, err
}
} else {
break
}
}
copy(p, outBuf.Bytes())

return out.Bytes(), nil
}

func (tr *TimeoutReader) Read(p []byte) (int, error) {
return tr.r.Read(p)
return outBuf.Len(), nil
}

func cleanSlice(src []byte) {
Expand Down
13 changes: 12 additions & 1 deletion subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"github.com/andyollylarkin/watermill-net/internal"
)

const maxBodyLen = 20 << 30

type sub struct {
Topic string
MsgChan chan *message.Message // client message chan
Expand Down Expand Up @@ -314,7 +316,7 @@ func (s *Subscriber) sendAck(ack bool, uuid string) error {
}

func (s *Subscriber) readContent() {
var readTimeout = time.Second * 3
var readTimeout = time.Second * 5

for {
select {
Expand All @@ -324,6 +326,7 @@ func (s *Subscriber) readContent() {
s.mu.RLock()
s.conn.SetReadDeadline(time.Now().Add(readTimeout))
r := internal.NewTimeoutReader(s.conn, readTimeout)
// r := bufio.NewReader(s.conn)

lenRaw, err := r.ReadBytes(internal.LenDelimiter)
if err != nil {
Expand All @@ -337,6 +340,14 @@ func (s *Subscriber) readContent() {
}

readLen := internal.ReadLen(lenRaw[:len(lenRaw)-1]) // trim len delimiter
if readLen <= 0 || readLen > maxBodyLen {
if s.logger != nil {
s.logger.Info("Message body out of range", watermill.LogFields{"len": readLen})
}

continue
}

lr := io.LimitReader(r, int64(readLen))

respBody := make([]byte, readLen)
Expand Down

0 comments on commit a583535

Please sign in to comment.