Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for proxy prefix early, so RemoteAddr doesn't block Accept #2

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 58 additions & 52 deletions protocol.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,42 @@
// Package proxoproto implements a net.Listener supporting HAProxy PROTO protocol.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo in package name

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no intention of fixing this. You can take this broken PR and do with it as you please.

Or take this PR and fix the typo later.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, just directing this at @armon when he merges it. Thanks for your help.

//
// See http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt for details.
package proxyproto

import (
"bufio"
"bytes"
"fmt"
"io"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
)

var (
// prefix is the string we look for at the start of a connection
// to check if this connection is using the proxy protocol
prefix = []byte("PROXY ")
prefixLen = len(prefix)
)
// prefix is the string we look for at the start of a connection
// to check if this connection is using the proxy protocol.
var prefix = []byte("PROXY ")

// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol (version 1).
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
// Listener wraps an underlying listener whose connections may be
// using the HAProxy Proxy Protocol (version 1).
// If the connection is using the protocol, RemoteAddr will return the
// correct client address.
type Listener struct {
Listener net.Listener

initOnce sync.Once // guards init, which sets the following:
connc chan interface{} // *conn or error
}

// Conn is used to wrap and underlying connection which
func (p *Listener) init() {
p.connc = make(chan interface{})
}

// conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
type Conn struct {
type conn struct {
bufReader *bufio.Reader
conn net.Conn
dstAddr *net.TCPAddr
Expand All @@ -41,38 +46,49 @@ type Conn struct {

// Accept waits for and returns the next connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
p.initOnce.Do(p.init)
// Get the underlying connection
conn, err := p.Listener.Accept()
rawc, err := p.Listener.Accept()
if err != nil {
return nil, err
}
return NewConn(conn), nil
go func() {
c, err := newConn(rawc)
if err != nil {
p.connc <- err
} else {
p.connc <- c
}
}()
v := <-p.connc
if c, ok := v.(*conn); ok {
return c, nil
} else {
return nil, v.(error)
}
}

// Close closes the underlying listener.
func (p *Listener) Close() error {
return p.Listener.Close()
}
func (p *Listener) Close() error { return p.Listener.Close() }

// Addr returns the underlying listener's network address.
func (p *Listener) Addr() net.Addr {
return p.Listener.Addr()
}
func (p *Listener) Addr() net.Addr { return p.Listener.Addr() }

// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn) *Conn {
pConn := &Conn{
bufReader: bufio.NewReader(conn),
conn: conn,
func newConn(c net.Conn) (*conn, error) {
pc := &conn{
bufReader: bufio.NewReader(c),
conn: c,
}
return pConn
if err := pc.checkPrefix(); err != nil {
return nil, err
}
return pc, nil
}

// Read is check for the proxy protocol header when doing
// the initial scan. If there is an error parsing the header,
// it is returned and the socket is closed.
func (p *Conn) Read(b []byte) (int, error) {
func (p *conn) Read(b []byte) (int, error) {
var err error
p.once.Do(func() { err = p.checkPrefix() })
if err != nil {
Expand All @@ -81,15 +97,15 @@ func (p *Conn) Read(b []byte) (int, error) {
return p.bufReader.Read(b)
}

func (p *Conn) Write(b []byte) (int, error) {
func (p *conn) Write(b []byte) (int, error) {
return p.conn.Write(b)
}

func (p *Conn) Close() error {
func (p *conn) Close() error {
return p.conn.Close()
}

func (p *Conn) LocalAddr() net.Addr {
func (p *conn) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}

Expand All @@ -100,35 +116,25 @@ func (p *Conn) LocalAddr() net.Addr {
// Once implication of this is that the call could block if the
// client is slow. Using a Deadline is recommended if this is called
// before Read()
func (p *Conn) RemoteAddr() net.Addr {
p.once.Do(func() {
if err := p.checkPrefix(); err != nil && err != io.EOF {
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
}
})
func (p *conn) RemoteAddr() net.Addr {
if p.srcAddr != nil {
return p.srcAddr
}
return p.conn.RemoteAddr()
}

func (p *Conn) SetDeadline(t time.Time) error {
return p.conn.SetDeadline(t)
}

func (p *Conn) SetReadDeadline(t time.Time) error {
return p.conn.SetReadDeadline(t)
}

func (p *Conn) SetWriteDeadline(t time.Time) error {
return p.conn.SetWriteDeadline(t)
}
func (p *conn) SetDeadline(t time.Time) error { return p.conn.SetDeadline(t) }
func (p *conn) SetReadDeadline(t time.Time) error { return p.conn.SetReadDeadline(t) }
func (p *conn) SetWriteDeadline(t time.Time) error { return p.conn.SetWriteDeadline(t) }

func (p *Conn) checkPrefix() error {
func (p *conn) checkPrefix() error {
// Incrementally check each byte of the prefix
for i := 1; i <= prefixLen; i++ {
for i := 1; i <= len(prefix); i++ {
inp, err := p.bufReader.Peek(i)
if err != nil {
// TOOD: this isn't right. it returns EOF on
// non-PROXY connections sending payloads
// shorter than len(prefix).
return err
}

Expand Down
16 changes: 8 additions & 8 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ func TestPassthrough(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}

pl := &Listener{l}
pl := &Listener{Listener: l}
defer pl.Close()

go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand Down Expand Up @@ -58,8 +58,8 @@ func TestParse_ipv4(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}

pl := &Listener{l}
pl := &Listener{Listener: l}
defer pl.Close()

go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand Down Expand Up @@ -117,8 +117,8 @@ func TestParse_ipv6(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}

pl := &Listener{l}
pl := &Listener{Listener: l}
defer pl.Close()

go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand Down Expand Up @@ -176,8 +176,8 @@ func TestParse_BadHeader(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}

pl := &Listener{l}
pl := &Listener{Listener: l}
defer pl.Close()

go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
Expand Down