Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions p2p/kademlia/conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package kademlia
import (
"context"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -123,56 +124,85 @@ func NewSecureServerConn(_ context.Context, tc credentials.TransportCredentials,

// Read implements net.Conn's Read interface
func (conn *connWrapper) Read(b []byte) (n int, err error) {
if conn == nil || conn.secureConn == nil {
return 0, io.ErrClosedPipe
}
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.secureConn.Read(b)
}

// Write implements net.Conn's Write interface
func (conn *connWrapper) Write(b []byte) (n int, err error) {
if conn == nil || conn.secureConn == nil {
return 0, io.ErrClosedPipe
}
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.secureConn.Write(b)
}

// Close implements net.Conn's Close interface
func (conn *connWrapper) Close() error {
if conn == nil {
return nil
}
conn.mtx.Lock()
defer conn.mtx.Unlock()
conn.secureConn.Close()
return conn.rawConn.Close()
if conn.secureConn != nil {
_ = conn.secureConn.Close()
}
if conn.rawConn != nil {
return conn.rawConn.Close()
}
return nil
}

// LocalAddr implements net.Conn's LocalAddr interface
func (conn *connWrapper) LocalAddr() net.Addr {
if conn == nil || conn.rawConn == nil {
return nil
}
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.rawConn.LocalAddr()
}

// RemoteAddr implements net.Conn's RemoteAddr interface
func (conn *connWrapper) RemoteAddr() net.Addr {
if conn == nil || conn.rawConn == nil {
return nil
}
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.rawConn.RemoteAddr()
}

// SetDeadline implements net.Conn's SetDeadline interface
func (conn *connWrapper) SetDeadline(t time.Time) error {
if conn == nil || conn.secureConn == nil {
return io.ErrClosedPipe
}
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.secureConn.SetDeadline(t)
}

// SetReadDeadline implements net.Conn's SetReadDeadline interface
func (conn *connWrapper) SetReadDeadline(t time.Time) error {
if conn == nil || conn.secureConn == nil {
return io.ErrClosedPipe
}
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.secureConn.SetReadDeadline(t)
}

// SetWriteDeadline implements net.Conn's SetWriteDeadline interface
func (conn *connWrapper) SetWriteDeadline(t time.Time) error {
if conn == nil || conn.secureConn == nil {
return io.ErrClosedPipe
}
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.secureConn.SetWriteDeadline(t)
Expand Down Expand Up @@ -200,6 +230,10 @@ func (pool *ConnPool) Add(addr string, conn net.Conn) {
pool.mtx.Lock()
defer pool.mtx.Unlock()

if w, ok := conn.(*connWrapper); ok && w == nil {
return
}

if item, ok := pool.conns[addr]; ok {
_ = item.conn.Close()
pool.conns[addr] = &connectionItem{lastAccess: time.Now().UTC(), conn: conn}
Expand Down