Skip to content

Commit

Permalink
Merge pull request #624 from rgooch/master
Browse files Browse the repository at this point in the history
Fix connection leak and log spamming in lib/net/reverseconnection.
  • Loading branch information
rgooch committed Jul 4, 2019
2 parents 45f045f + ba2d9d6 commit a0802cd
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 36 deletions.
8 changes: 2 additions & 6 deletions lib/net/reverseconnection/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
)

type acceptEvent struct {
conn *Conn
conn *listenerConn
error error
}

type Conn struct {
type listenerConn struct {
libnet.TCPConn
listener *Listener
}
Expand Down Expand Up @@ -50,10 +50,6 @@ type ReverseListenerConfig struct {
MaximumInterval time.Duration // Maximum interval to request connections.
}

func (conn *Conn) Close() error {
return conn.close()
}

// Listen creates a listener which may be used to accept incoming connections.
// It listens on all available IP addresses on the local system.
func Listen(network string, portNumber uint, logger log.DebugLogger) (
Expand Down
33 changes: 22 additions & 11 deletions lib/net/reverseconnection/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,13 @@ func (d *Dialer) lookup(address string) net.Conn {
}

func (d *Dialer) connectHandler(w http.ResponseWriter, req *http.Request) {
d.logger.Debugf(1, "%s request from remote: %s\n",
req.Method, req.RemoteAddr)
if req.Method != "CONNECT" {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusMethodNotAllowed)
d.logger.Debugf(0, "rejecting method=%s from remote: %s\n",
req.Method, req.RemoteAddr)
return
}
hijacker, ok := w.(http.Hijacker)
Expand All @@ -118,26 +122,30 @@ func (d *Dialer) connectHandler(w http.ResponseWriter, req *http.Request) {
d.logger.Println("not a hijacker ", req.RemoteAddr)
return
}
d.connectionMapLock.Lock()
if conn, ok := d.connectionMap[req.RemoteAddr]; ok {
// We have nothing to detect if the remote closed, so assume the remote
// is retrying and close the old (unused) connection.
delete(d.connectionMap, req.RemoteAddr)
d.connectionMapLock.Unlock()
conn.Close()
d.logger.Debugf(0, "closed unused duplicate remote: %s\n",
req.RemoteAddr)
} else {
d.connectionMapLock.Unlock()
}
conn, _, err := hijacker.Hijack()
if err != nil {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusInternalServerError)
d.logger.Println("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
d.logger.Printf("rpc hijacking %s: %s\n", req.RemoteAddr, err)
return
}
defer func() {
if conn != nil {
conn.Close()
}
}()
d.connectionMapLock.Lock()
_, ok = d.connectionMap[req.RemoteAddr]
d.connectionMapLock.Unlock()
if ok {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusAlreadyReported)
return
}
_, err = io.WriteString(conn, "HTTP/1.0 "+connectString+"\n\n")
if err != nil {
d.logger.Println("error writing connect message: ", err.Error())
Expand All @@ -153,11 +161,14 @@ func (d *Dialer) connectHandler(w http.ResponseWriter, req *http.Request) {
d.logger.Printf("error writing ReverseDialerMessage: %s\n", err)
return
}
// Ensure we don't write anything else until the other end has drained it's
// Ensure we don't write anything else until the other end has drained its
// buffer.
buffer := make([]byte, 1)
d.logger.Debugf(1, "waiting for sync byte from remote: %s\n",
req.RemoteAddr)
if _, err := conn.Read(buffer); err != nil {
d.logger.Printf("error reading sync byte: %s\n", err)
d.logger.Printf("error reading sync byte from: %s: %s\n",
req.RemoteAddr, err)
return
}
if d.add(req.RemoteAddr, conn) {
Expand Down
57 changes: 43 additions & 14 deletions lib/net/reverseconnection/listen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ func TestInjectAccept(t *testing.T) {
if err != nil {
t.Fatal(err)
}
fakeListener.acceptChannel <- acceptEvent{&Conn{TCPConn: slaveConn}, nil}
fakeListener.acceptChannel <- acceptEvent{
&listenerConn{TCPConn: slaveConn}, nil}
if err := testHttpConnection(masterConn, logger); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -208,43 +209,71 @@ func TestListenAndHttpServe(t *testing.T) {
func TestReverseListenTcp(t *testing.T) {
tLogger := testlogger.New(t)
// Set up slave.
slaveLogger := prefixlogger.New("slave: ", tLogger)
slaveListener, slavePortNumber := createTestListener(slaveLogger)
slaveAddress := fmt.Sprintf("localhost:%d", slavePortNumber)
slaveListener, slavePortNumber := createTestListener(
prefixlogger.New("slave: ", tLogger))
slaveAddress := fmt.Sprintf("127.0.0.1:%d", slavePortNumber)
t.Logf("slaveAddress: %s", slaveAddress)
// Set up master
masterLogger := prefixlogger.New("master: ", tLogger)
masterListener, masterPortNumber := createTestRealListener(masterLogger)
masterMux := http.NewServeMux()
go http.Serve(masterListener, masterMux)
dialer := NewDialer(nil, masterMux, 0, 0, masterLogger)
// Make slave connect back to master.
slaveLogger.Print("making slave connect to master")
loopbackIP := [4]byte{127, 0, 0, 1}
if slaveListener.connectionMap[loopbackIP] > 0 {
t.Fatalf("slave listener already has %d connections",
slaveListener.connectionMap[loopbackIP])
}
if dialer.connectionMap[slaveAddress] != nil {
t.Fatal("master dialer already has a connection")
}
t.Log("making slave connect to master")
go slaveListener.connectLoop(ReverseListenerConfig{
Network: "tcp",
ServerAddress: fmt.Sprintf("127.0.0.1:%d", masterPortNumber),
MinimumInterval: time.Millisecond,
},
"127.0.0.1")
time.Sleep(time.Millisecond * 5)
masterLogger.Print("making and testing connection")
time.Sleep(time.Millisecond * 50)
if slaveListener.connectionMap[loopbackIP] > 0 {
t.Fatalf(
"slave listener has %d loopback connections, expected none",
slaveListener.connectionMap[loopbackIP])
}
if dialer.connectionMap[slaveAddress] == nil {
t.Fatalf("master dialer does not have a connection yet")
}
time.Sleep(time.Millisecond * 50)
t.Log("making and testing connection")
masterConn, err := dialer.Dial("tcp", slaveAddress)
if err != nil {
masterLogger.Fatal(err)
t.Fatal(err)
}
if dialer.connectionMap[slaveAddress] != nil {
t.Fatal("master dialer still has a connection")
}
slaveConn, err := slaveListener.Accept()
if err != nil {
slaveLogger.Fatal(err)
t.Fatal(err)
}
if _, ok := slaveConn.(libnet.TCPConn); !ok {
slaveLogger.Fatalf("non-TCP connection: %T", slaveConn)
if _, ok := slaveConn.(*listenerConn); !ok {
t.Fatalf("not a *listenerConn connection: %T", slaveConn)
}
if slaveListener.connectionMap[loopbackIP] != 0 {
t.Fatalf("slave listener has %d connections, expected 0",
slaveListener.connectionMap[loopbackIP])
}
if dialer.connectionMap[slaveAddress] != nil {
t.Fatal("master dialer still has a connection")
}
go func() {
if _, err := io.Copy(slaveConn, slaveConn); err != nil {
slaveLogger.Println(err)
t.Log(err)
}
}()
if err := testEcho(masterConn); err != nil {
masterLogger.Fatal(err)
t.Fatal(err)
}
}

Expand All @@ -271,7 +300,7 @@ func TestReverseListenHttp(t *testing.T) {
MinimumInterval: time.Millisecond,
},
"127.0.0.1")
time.Sleep(time.Millisecond * 5)
time.Sleep(time.Millisecond * 50)
masterLogger.Print("making and testing connection")
err := makeAndTestHttpConnection(dialer, slavePortNumber, masterLogger)
if err != nil {
Expand Down
18 changes: 13 additions & 5 deletions lib/net/reverseconnection/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (

var (
errorNotFound = errors.New("HTTP method not found")
errorLoopback = errors.New("loopback address")
)

func getIp4Address(conn net.Conn) (ip4Address, error) {
Expand All @@ -45,6 +46,9 @@ func getIp4AddressFromAddress(address string) (ip4Address, error) {
if ip == nil {
return ip4Address{}, errors.New("failed to parse: " + address)
}
if ip.IsLoopback() {
return ip4Address{}, errorLoopback
}
ip = ip.To4()
if ip == nil {
return ip4Address{}, errors.New(address + " is not IPv4")
Expand Down Expand Up @@ -76,16 +80,18 @@ func sleep(minInterval, maxInterval time.Duration) {
time.Sleep(minInterval + jit)
}

func (conn *Conn) close() error {
func (conn *listenerConn) Close() error {
if ip, err := getIp4Address(conn); err != nil {
conn.listener.logger.Println(err)
if err != errorLoopback {
conn.listener.logger.Println(err)
}
} else {
conn.listener.forget(conn.RemoteAddr().String(), ip)
}
return conn.TCPConn.Close()
}

func (l *Listener) accept() (*Conn, error) {
func (l *Listener) accept() (*listenerConn, error) {
if l.closed {
return nil, errors.New("listener is closed")
}
Expand Down Expand Up @@ -122,7 +128,8 @@ func (l *Listener) listen(acceptChannel chan<- acceptEvent) {
continue
}
l.remember(conn)
acceptChannel <- acceptEvent{&Conn{TCPConn: tcpConn, listener: l}, err}
acceptChannel <- acceptEvent{
&listenerConn{TCPConn: tcpConn, listener: l}, err}
}
}

Expand Down Expand Up @@ -277,7 +284,8 @@ func (l *Listener) connect(network, serverAddress string, timeout time.Duration,
}
logger.Println("remote has consumed, injecting to local listener")
l.remember(rawConn)
l.acceptChannel <- acceptEvent{&Conn{TCPConn: tcpConn, listener: l}, nil}
l.acceptChannel <- acceptEvent{
&listenerConn{TCPConn: tcpConn, listener: l}, nil}
rawConn = nil // Prevent Close on return.
return &message, nil
}

0 comments on commit a0802cd

Please sign in to comment.