Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a timeout to RemoteAddr() to allow http.Serve in go < 1.6 work #4

Merged
merged 2 commits into from
Jul 18, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Using the library is very simple:
list, err := net.Listen("tcp", "...")
// Wrap listener in a proxyproto listener
proxyList := &proxyproto.Listener{list}
proxyList := &proxyproto.Listener{Listener: list}
conn, err :=proxyList.Accept()
...
Expand Down
41 changes: 30 additions & 11 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,24 @@ var (
// whose connections may be using the HAProxy Proxy Protocol (version 1).
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
//
// Optionally define ProxyHeaderTimeout to set a maximum time to
// receive the Proxy Protocol Header. Zero means no timeout.
type Listener struct {
Listener net.Listener
Listener net.Listener
ProxyHeaderTimeout time.Duration
}

// Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
type Conn struct {
bufReader *bufio.Reader
conn net.Conn
dstAddr *net.TCPAddr
srcAddr *net.TCPAddr
once sync.Once
bufReader *bufio.Reader
conn net.Conn
dstAddr *net.TCPAddr
srcAddr *net.TCPAddr
once sync.Once
proxyHeaderTimeout time.Duration
}

// Accept waits for and returns the next connection to the listener.
Expand All @@ -46,7 +51,7 @@ func (p *Listener) Accept() (net.Conn, error) {
if err != nil {
return nil, err
}
return NewConn(conn), nil
return NewConn(conn, p.ProxyHeaderTimeout), nil
}

// Close closes the underlying listener.
Expand All @@ -61,10 +66,11 @@ func (p *Listener) Addr() net.Addr {

// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn) *Conn {
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
pConn := &Conn{
bufReader: bufio.NewReader(conn),
conn: conn,
bufReader: bufio.NewReader(conn),
conn: conn,
proxyHeaderTimeout: timeout,
}
return pConn
}
Expand Down Expand Up @@ -104,6 +110,8 @@ func (p *Conn) RemoteAddr() net.Addr {
p.once.Do(func() {
if err := p.checkPrefix(); err != nil && err != io.EOF {
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
p.Close()
p.bufReader = bufio.NewReader(p.conn)
}
})
if p.srcAddr != nil {
Expand All @@ -125,11 +133,22 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
}

func (p *Conn) checkPrefix() error {
if p.proxyHeaderTimeout != 0 {
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
p.conn.SetReadDeadline(readDeadLine)
defer p.conn.SetReadDeadline(time.Time{})
}

// Incrementally check each byte of the prefix
for i := 1; i <= prefixLen; i++ {
inp, err := p.bufReader.Peek(i)

if err != nil {
return err
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
return nil
} else {
return err
}
}

// Check for a prefix mis-match, quit early
Expand Down
76 changes: 70 additions & 6 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package proxyproto

import (
"bytes"
"io"
"net"
"testing"
"time"
)

func TestPassthrough(t *testing.T) {
Expand All @@ -13,7 +13,7 @@ func TestPassthrough(t *testing.T) {
t.Fatalf("err: %v", err)
}

pl := &Listener{l}
pl := &Listener{Listener: l}

go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand Down Expand Up @@ -53,13 +53,77 @@ func TestPassthrough(t *testing.T) {
}
}

func TestTimeout(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

clientWriteDelay := 200 * time.Millisecond
proxyHeaderTimeout := 50 * time.Millisecond
pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}

go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

// Do not send data for a while
time.Sleep(clientWriteDelay)

conn.Write([]byte("ping"))
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("pong")) {
t.Fatalf("bad: %v", recv)
}
}()

conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

// Check the remote addr is the original 127.0.0.1
remoteAddrStartTime := time.Now()
addr := conn.RemoteAddr().(*net.TCPAddr)
if addr.IP.String() != "127.0.0.1" {
t.Fatalf("bad: %v", addr)
}
remoteAddrDuration := time.Since(remoteAddrStartTime)

// Check RemoteAddr() call did timeout
if remoteAddrDuration >= clientWriteDelay {
t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
}

recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(recv, []byte("ping")) {
t.Fatalf("bad: %v", recv)
}

if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}
}

func TestParse_ipv4(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

pl := &Listener{l}
pl := &Listener{Listener: l}

go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand Down Expand Up @@ -118,7 +182,7 @@ func TestParse_ipv6(t *testing.T) {
t.Fatalf("err: %v", err)
}

pl := &Listener{l}
pl := &Listener{Listener: l}

go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand Down Expand Up @@ -177,7 +241,7 @@ func TestParse_BadHeader(t *testing.T) {
t.Fatalf("err: %v", err)
}

pl := &Listener{l}
pl := &Listener{Listener: l}

go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand All @@ -194,7 +258,7 @@ func TestParse_BadHeader(t *testing.T) {

recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != io.EOF {
if err == nil {
t.Fatalf("err: %v", err)
}
}()
Expand Down