diff --git a/README.md b/README.md index 25a779c..47e9718 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Using the library is very simple: list, err := net.Listen("tcp", "...") // Wrap listener in a proxyproto listener -proxyList := &proxyproto.Listener{list} +proxyList := &proxyproto.Listener{Listener: list} conn, err :=proxyList.Accept() ... diff --git a/protocol.go b/protocol.go index 2fc1dfc..dfa6eb0 100644 --- a/protocol.go +++ b/protocol.go @@ -24,19 +24,24 @@ var ( // whose connections may be using the HAProxy Proxy Protocol (version 1). // If the connection is using the protocol, the RemoteAddr() will return // the correct client address. +// +// Optionally define ProxyHeaderTimeout to set a maximum time to +// receive the Proxy Protocol Header. Zero means no timeout. type Listener struct { - Listener net.Listener + Listener net.Listener + ProxyHeaderTimeout time.Duration } // Conn is used to wrap and underlying connection which // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will // return the address of the client instead of the proxy address. type Conn struct { - bufReader *bufio.Reader - conn net.Conn - dstAddr *net.TCPAddr - srcAddr *net.TCPAddr - once sync.Once + bufReader *bufio.Reader + conn net.Conn + dstAddr *net.TCPAddr + srcAddr *net.TCPAddr + once sync.Once + proxyHeaderTimeout time.Duration } // Accept waits for and returns the next connection to the listener. @@ -46,7 +51,7 @@ func (p *Listener) Accept() (net.Conn, error) { if err != nil { return nil, err } - return NewConn(conn), nil + return NewConn(conn, p.ProxyHeaderTimeout), nil } // Close closes the underlying listener. @@ -61,10 +66,11 @@ func (p *Listener) Addr() net.Addr { // NewConn is used to wrap a net.Conn that may be speaking // the proxy protocol into a proxyproto.Conn -func NewConn(conn net.Conn) *Conn { +func NewConn(conn net.Conn, timeout time.Duration) *Conn { pConn := &Conn{ - bufReader: bufio.NewReader(conn), - conn: conn, + bufReader: bufio.NewReader(conn), + conn: conn, + proxyHeaderTimeout: timeout, } return pConn } @@ -104,6 +110,8 @@ func (p *Conn) RemoteAddr() net.Addr { p.once.Do(func() { if err := p.checkPrefix(); err != nil && err != io.EOF { log.Printf("[ERR] Failed to read proxy prefix: %v", err) + p.Close() + p.bufReader = bufio.NewReader(p.conn) } }) if p.srcAddr != nil { @@ -125,11 +133,22 @@ func (p *Conn) SetWriteDeadline(t time.Time) error { } func (p *Conn) checkPrefix() error { + if p.proxyHeaderTimeout != 0 { + readDeadLine := time.Now().Add(p.proxyHeaderTimeout) + p.conn.SetReadDeadline(readDeadLine) + defer p.conn.SetReadDeadline(time.Time{}) + } + // Incrementally check each byte of the prefix for i := 1; i <= prefixLen; i++ { inp, err := p.bufReader.Peek(i) + if err != nil { - return err + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + return nil + } else { + return err + } } // Check for a prefix mis-match, quit early diff --git a/protocol_test.go b/protocol_test.go index ba70ee9..ab29731 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -2,9 +2,9 @@ package proxyproto import ( "bytes" - "io" "net" "testing" + "time" ) func TestPassthrough(t *testing.T) { @@ -13,7 +13,7 @@ func TestPassthrough(t *testing.T) { t.Fatalf("err: %v", err) } - pl := &Listener{l} + pl := &Listener{Listener: l} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) @@ -53,13 +53,77 @@ func TestPassthrough(t *testing.T) { } } +func TestTimeout(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + clientWriteDelay := 200 * time.Millisecond + proxyHeaderTimeout := 50 * time.Millisecond + pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout} + + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Do not send data for a while + time.Sleep(clientWriteDelay) + + conn.Write([]byte("ping")) + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("pong")) { + t.Fatalf("bad: %v", recv) + } + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + // Check the remote addr is the original 127.0.0.1 + remoteAddrStartTime := time.Now() + addr := conn.RemoteAddr().(*net.TCPAddr) + if addr.IP.String() != "127.0.0.1" { + t.Fatalf("bad: %v", addr) + } + remoteAddrDuration := time.Since(remoteAddrStartTime) + + // Check RemoteAddr() call did timeout + if remoteAddrDuration >= clientWriteDelay { + t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration) + } + + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(recv, []byte("ping")) { + t.Fatalf("bad: %v", recv) + } + + if _, err := conn.Write([]byte("pong")); err != nil { + t.Fatalf("err: %v", err) + } +} + func TestParse_ipv4(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } - pl := &Listener{l} + pl := &Listener{Listener: l} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) @@ -118,7 +182,7 @@ func TestParse_ipv6(t *testing.T) { t.Fatalf("err: %v", err) } - pl := &Listener{l} + pl := &Listener{Listener: l} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) @@ -177,7 +241,7 @@ func TestParse_BadHeader(t *testing.T) { t.Fatalf("err: %v", err) } - pl := &Listener{l} + pl := &Listener{Listener: l} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) @@ -194,7 +258,7 @@ func TestParse_BadHeader(t *testing.T) { recv := make([]byte, 4) _, err = conn.Read(recv) - if err != io.EOF { + if err == nil { t.Fatalf("err: %v", err) } }()