diff --git a/p2p/kademlia/conn_pool.go b/p2p/kademlia/conn_pool.go index 8c5876f3..66b33e28 100644 --- a/p2p/kademlia/conn_pool.go +++ b/p2p/kademlia/conn_pool.go @@ -3,6 +3,7 @@ package kademlia import ( "context" "fmt" + "io" "net" "sync" "sync/atomic" @@ -123,6 +124,9 @@ func NewSecureServerConn(_ context.Context, tc credentials.TransportCredentials, // Read implements net.Conn's Read interface func (conn *connWrapper) Read(b []byte) (n int, err error) { + if conn == nil || conn.secureConn == nil { + return 0, io.ErrClosedPipe + } conn.mtx.Lock() defer conn.mtx.Unlock() return conn.secureConn.Read(b) @@ -130,6 +134,9 @@ func (conn *connWrapper) Read(b []byte) (n int, err error) { // Write implements net.Conn's Write interface func (conn *connWrapper) Write(b []byte) (n int, err error) { + if conn == nil || conn.secureConn == nil { + return 0, io.ErrClosedPipe + } conn.mtx.Lock() defer conn.mtx.Unlock() return conn.secureConn.Write(b) @@ -137,14 +144,25 @@ func (conn *connWrapper) Write(b []byte) (n int, err error) { // Close implements net.Conn's Close interface func (conn *connWrapper) Close() error { + if conn == nil { + return nil + } conn.mtx.Lock() defer conn.mtx.Unlock() - conn.secureConn.Close() - return conn.rawConn.Close() + if conn.secureConn != nil { + _ = conn.secureConn.Close() + } + if conn.rawConn != nil { + return conn.rawConn.Close() + } + return nil } // LocalAddr implements net.Conn's LocalAddr interface func (conn *connWrapper) LocalAddr() net.Addr { + if conn == nil || conn.rawConn == nil { + return nil + } conn.mtx.Lock() defer conn.mtx.Unlock() return conn.rawConn.LocalAddr() @@ -152,6 +170,9 @@ func (conn *connWrapper) LocalAddr() net.Addr { // RemoteAddr implements net.Conn's RemoteAddr interface func (conn *connWrapper) RemoteAddr() net.Addr { + if conn == nil || conn.rawConn == nil { + return nil + } conn.mtx.Lock() defer conn.mtx.Unlock() return conn.rawConn.RemoteAddr() @@ -159,6 +180,9 @@ func (conn *connWrapper) RemoteAddr() net.Addr { // SetDeadline implements net.Conn's SetDeadline interface func (conn *connWrapper) SetDeadline(t time.Time) error { + if conn == nil || conn.secureConn == nil { + return io.ErrClosedPipe + } conn.mtx.Lock() defer conn.mtx.Unlock() return conn.secureConn.SetDeadline(t) @@ -166,6 +190,9 @@ func (conn *connWrapper) SetDeadline(t time.Time) error { // SetReadDeadline implements net.Conn's SetReadDeadline interface func (conn *connWrapper) SetReadDeadline(t time.Time) error { + if conn == nil || conn.secureConn == nil { + return io.ErrClosedPipe + } conn.mtx.Lock() defer conn.mtx.Unlock() return conn.secureConn.SetReadDeadline(t) @@ -173,6 +200,9 @@ func (conn *connWrapper) SetReadDeadline(t time.Time) error { // SetWriteDeadline implements net.Conn's SetWriteDeadline interface func (conn *connWrapper) SetWriteDeadline(t time.Time) error { + if conn == nil || conn.secureConn == nil { + return io.ErrClosedPipe + } conn.mtx.Lock() defer conn.mtx.Unlock() return conn.secureConn.SetWriteDeadline(t) @@ -200,6 +230,10 @@ func (pool *ConnPool) Add(addr string, conn net.Conn) { pool.mtx.Lock() defer pool.mtx.Unlock() + if w, ok := conn.(*connWrapper); ok && w == nil { + return + } + if item, ok := pool.conns[addr]; ok { _ = item.conn.Close() pool.conns[addr] = &connectionItem{lastAccess: time.Now().UTC(), conn: conn}