Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions p2p/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,7 @@ func (p *Peer) handle(msg Msg) error {
case msg.Code == discMsg:
// This is the last message. We don't need to discard or
// check errors because, the connection will be closed after it.
var m struct{ R DiscReason }
rlp.Decode(msg.Payload, &m)
return m.R
return decodeDisconnectMessage(msg.Payload)
case msg.Code < baseProtocolLength:
// ignore other base protocol messages
return msg.Discard()
Expand All @@ -372,6 +370,27 @@ func (p *Peer) handle(msg Msg) error {
return nil
}

// decodeDisconnectMessage decodes the payload of discMsg.
func decodeDisconnectMessage(r io.Reader) (reason DiscReason) {
s := rlp.NewStream(r, 100)
k, _, err := s.Kind()
if err != nil {
return DiscInvalid
}
if k == rlp.List {
s.List()
err = s.Decode(&reason)
} else {
// Legacy path: some implementations, including geth, used to send the disconnect
// reason as a byte array by accident.
err = s.Decode(&reason)
}
if err != nil {
reason = DiscInvalid
}
return reason
}

func countMatchingProtocols(protocols []Protocol, caps []Cap) int {
n := 0
for _, cap := range caps {
Expand Down
5 changes: 4 additions & 1 deletion p2p/peer_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ const (
DiscSelf
DiscReadTimeout
DiscSubprotocolError = DiscReason(0x10)

DiscInvalid = 0xff
)

var discReasonToString = [...]string{
Expand All @@ -86,10 +88,11 @@ var discReasonToString = [...]string{
DiscSelf: "connected to self",
DiscReadTimeout: "read timeout",
DiscSubprotocolError: "subprotocol error",
DiscInvalid: "invalid disconnect reason",
}

func (d DiscReason) String() string {
if len(discReasonToString) <= int(d) {
if len(discReasonToString) <= int(d) || discReasonToString[d] == "" {
return fmt.Sprintf("unknown disconnect reason %d", d)
}
return discReasonToString[d]
Expand Down
24 changes: 10 additions & 14 deletions p2p/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,14 @@ func (t *rlpxTransport) close(err error) {
// Tell the remote end why we're disconnecting if possible.
// We only bother doing this if the underlying connection supports
// setting a timeout tough.
if t.conn != nil {
if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
deadline := time.Now().Add(discWriteTimeout)
if err := t.conn.SetWriteDeadline(deadline); err == nil {
// Connection supports write deadline.
t.wbuf.Reset()
rlp.Encode(&t.wbuf, []DiscReason{r})
t.conn.Write(discMsg, t.wbuf.Bytes())
}
if reason, ok := err.(DiscReason); ok && reason != DiscNetworkError {
// We do not use the WriteMsg func since we want a custom deadline
deadline := time.Now().Add(discWriteTimeout)
if err := t.conn.SetWriteDeadline(deadline); err == nil {
// Connection supports write deadline.
t.wbuf.Reset()
rlp.Encode(&t.wbuf, []any{reason})
t.conn.Write(discMsg, t.wbuf.Bytes())
}
}
t.conn.Close()
Expand Down Expand Up @@ -163,11 +162,8 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
if msg.Code == discMsg {
// Disconnect before protocol handshake is valid according to the
// spec and we send it ourself if the post-handshake checks fail.
// We can't return the reason directly, though, because it is echoed
// back otherwise. Wrap it in a string instead.
var reason [1]DiscReason
rlp.Decode(msg.Payload, &reason)
return nil, reason[0]
r := decodeDisconnectMessage(msg.Payload)
return nil, r
}
if msg.Code != handshakeMsg {
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
Expand Down
10 changes: 8 additions & 2 deletions p2p/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestProtocolHandshake(t *testing.T) {
return
}

if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
if err := ExpectMsg(rlpx, discMsg, []any{DiscQuitting}); err != nil {
t.Errorf("error receiving disconnect: %v", err)
}
}()
Expand All @@ -112,7 +112,13 @@ func TestProtocolHandshakeErrors(t *testing.T) {
}{
{
code: discMsg,
msg: []DiscReason{DiscQuitting},
msg: []any{DiscQuitting},
err: DiscQuitting,
},
{
// legacy disconnect encoding as byte array
code: discMsg,
msg: []byte{byte(DiscQuitting)},
err: DiscQuitting,
},
{
Expand Down
Loading