Skip to content

Commit

Permalink
network: fixes to public address support (#5851)
Browse files Browse the repository at this point in the history
* Remove http.Request.RemoteAddr overwriting in request tracker
* Remove http.Request from request tracker
* Add a new remoteAddresss() method providing most meaningful address for incoming requests
  • Loading branch information
algorandskiy committed Dec 7, 2023
1 parent ebd3593 commit 9229066
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 97 deletions.
78 changes: 62 additions & 16 deletions network/requestTracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,27 @@ const (
)

// TrackerRequest hold the tracking data associated with a single request.
// It supposed by an upstream http.Handler called before the wsNetwork's ServeHTTP
// and wsNetwork's Listener (see Accept() method)
type TrackerRequest struct {
created time.Time
remoteHost string
remotePort string
remoteAddr string
request *http.Request
created time.Time
// remoteHost is IP address of the remote host and it is equal to either
// a host part of the remoteAddr or to the value of X-Forwarded-For header (UseXForwardedForAddressField config value).
remoteHost string
// remotePort is the port of the remote peer as reported by the connection or
// by the standard http.Request.RemoteAddr field.
remotePort string
// remoteAddr is IP:Port of the remote host retrieved from the connection
// or from the standard http.Request.RemoteAddr field.
// This field is the real address of the remote incoming connection.
remoteAddr string
// otherPublicAddr is the public address of the other node, as reported by the other node
// via the X-Algorand-Location header.
// It is used for logging and as a rootURL for when creating a new wsPeer from a request.
otherPublicAddr string

otherTelemetryGUID string
otherInstanceName string
otherPublicAddr string
connection net.Conn
noPrune bool
}
Expand All @@ -68,6 +80,43 @@ func makeTrackerRequest(remoteAddr, remoteHost, remotePort string, createTime ti
}
}

// remoteAddress a best guessed remote address for the request.
// Rational is the following:
// remoteAddress() is used either for logging or as rootURL for creating a new wsPeer.
// rootURL is an address to connect to. It is well defined only for peers from a phonebooks,
// and for incoming peers the best guess is either otherPublicAddr, remoteHost, or remoteAddr.
// - otherPublicAddr is provided by a remote peer by X-Algorand-Location header and cannot be trusted,
// but can be used if remoteHost matches to otherPublicAddr value. In this case otherPublicAddr is a better guess
// for a rootURL because it might include a port.
// - remoteHost is either a real address of the remote peer or a value of X-Forwarded-For header.
// Use it if remoteHost was taken from X-Forwarded-For header.
// Note, the remoteHost does not include a port since a listening port is not known.
// - remoteAddr is used otherwise.
func (tr *TrackerRequest) remoteAddress() string {
if len(tr.otherPublicAddr) != 0 {
url, err := ParseHostOrURL(tr.otherPublicAddr)
if err == nil && len(tr.remoteHost) > 0 && url.Hostname() == tr.remoteHost {
return tr.otherPublicAddr
}
}
url, err := ParseHostOrURL(tr.remoteAddr)
if err != nil {
// tr.remoteAddr can't be parsed so try to use tr.remoteHost
// there is a chance it came from a proxy and has a meaningful value
if len(tr.remoteHost) != 0 {
return tr.remoteHost
}
// otherwise fallback to tr.remoteAddr
return tr.remoteAddr
}
if url.Hostname() != tr.remoteHost {
// if remoteAddr's host not equal to remoteHost then the remoteHost
// is definitely came from a proxy, use it
return tr.remoteHost
}
return tr.remoteAddr
}

// hostIncomingRequests holds all the requests that are originating from a single host.
type hostIncomingRequests struct {
remoteHost string
Expand Down Expand Up @@ -142,7 +191,6 @@ func (ard *hostIncomingRequests) add(trackerRequest *TrackerRequest) {
}
// it's going to be added somewhere in the middle.
ard.requests = append(ard.requests[:itemIdx], append([]*TrackerRequest{trackerRequest}, ard.requests[itemIdx:]...)...)
return
}

// countConnections counts the number of connection that we have that occurred after the provided specified time
Expand Down Expand Up @@ -372,7 +420,7 @@ func (rt *RequestTracker) sendBlockedConnectionResponse(conn net.Conn, requestTi
func (rt *RequestTracker) pruneAcceptedConnections(pruneStartDate time.Time) {
localAddrToRemove := []net.Addr{}
for localAddr, request := range rt.acceptedConnections {
if request.noPrune == false && request.created.Before(pruneStartDate) {
if !request.noPrune && request.created.Before(pruneStartDate) {
localAddrToRemove = append(localAddrToRemove, localAddr)
}
}
Expand All @@ -397,7 +445,7 @@ func (rt *RequestTracker) getWaitUntilNoConnectionsChannel(checkInterval time.Du
return len(rt.httpConnections) == 0
}

for true {
for {
if checkEmpty(rt) {
close(done)
return
Expand Down Expand Up @@ -449,7 +497,7 @@ func (rt *RequestTracker) ServeHTTP(response http.ResponseWriter, request *http.
trackedRequest := rt.acceptedConnections[localAddr]
if trackedRequest != nil {
// update the original tracker request so that it won't get pruned.
if trackedRequest.noPrune == false {
if !trackedRequest.noPrune {
trackedRequest.noPrune = true
rt.hostRequests.convertToAdditionalRequest(trackedRequest)
}
Expand All @@ -464,10 +512,9 @@ func (rt *RequestTracker) ServeHTTP(response http.ResponseWriter, request *http.
}

// update the origin address.
rt.updateRequestRemoteAddr(trackedRequest, request)
rt.remoteHostProxyFix(request.Header, trackedRequest)

rt.httpConnectionsMu.Lock()
trackedRequest.request = request
trackedRequest.otherTelemetryGUID, trackedRequest.otherInstanceName, trackedRequest.otherPublicAddr = getCommonHeaders(request.Header)
rt.httpHostRequests.addRequest(trackedRequest)
rt.httpHostRequests.pruneRequests(rateLimitingWindowStartTime)
Expand Down Expand Up @@ -506,13 +553,12 @@ func (rt *RequestTracker) ServeHTTP(response http.ResponseWriter, request *http.

}

// updateRequestRemoteAddr updates the origin IP address in both the trackedRequest as well as in the request.RemoteAddr string
func (rt *RequestTracker) updateRequestRemoteAddr(trackedRequest *TrackerRequest, request *http.Request) {
originIP := rt.getForwardedConnectionAddress(request.Header)
// remoteHostProxyFix updates the origin IP address in the trackedRequest
func (rt *RequestTracker) remoteHostProxyFix(header http.Header, trackedRequest *TrackerRequest) {
originIP := rt.getForwardedConnectionAddress(header)
if originIP == nil {
return
}
request.RemoteAddr = originIP.String() + ":" + trackedRequest.remotePort
trackedRequest.remoteHost = originIP.String()
}

Expand Down
26 changes: 26 additions & 0 deletions network/requestTracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,32 @@ func TestRateLimiting(t *testing.T) {
}
}

func TestRemoteAddress(t *testing.T) {
partitiontest.PartitionTest(t)
t.Parallel()

tr := makeTrackerRequest("127.0.0.1:444", "", "", time.Now(), nil)
require.Equal(t, "127.0.0.1:444", tr.remoteAddr)
require.Equal(t, "127.0.0.1", tr.remoteHost)
require.Equal(t, "444", tr.remotePort)

require.Equal(t, "127.0.0.1:444", tr.remoteAddress())

// remoteHost set to something else via X-Forwared-For HTTP headers
tr.remoteHost = "10.0.0.1"
require.Equal(t, "10.0.0.1", tr.remoteAddress())

// otherPublicAddr is set via X-Algorand-Location HTTP header
// and matches to the remoteHost
tr.otherPublicAddr = "10.0.0.1:555"
require.Equal(t, "10.0.0.1:555", tr.remoteAddress())

// otherPublicAddr does not match remoteHost
tr.remoteHost = "127.0.0.1"
tr.otherPublicAddr = "127.0.0.99:555"
require.Equal(t, "127.0.0.1:444", tr.remoteAddress())
}

func TestIsLocalHost(t *testing.T) {
partitiontest.PartitionTest(t)
t.Parallel()
Expand Down
26 changes: 12 additions & 14 deletions network/wsNetwork.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,9 @@ func (wn *WebsocketNetwork) PublicAddress() string {
// If except is not nil then we will not send it to that neighboring Peer.
// if wait is true then the call blocks until the packet has actually been sent to all neighbors.
func (wn *WebsocketNetwork) Broadcast(ctx context.Context, tag protocol.Tag, data []byte, wait bool, except Peer) error {
dataArray := make([][]byte, 1, 1)
dataArray := make([][]byte, 1)
dataArray[0] = data
tagArray := make([]protocol.Tag, 1, 1)
tagArray := make([]protocol.Tag, 1)
tagArray[0] = tag
return wn.broadcaster.BroadcastArray(ctx, tagArray, dataArray, wait, except)
}
Expand Down Expand Up @@ -947,7 +947,7 @@ func (wn *WebsocketNetwork) checkProtocolVersionMatch(otherHeaders http.Header)
// checkIncomingConnectionVariables checks the variables that were provided on the request, and compares them to the
// local server supported parameters. If all good, it returns http.StatusOK; otherwise, it write the error to the ResponseWriter
// and returns the http status.
func (wn *WebsocketNetwork) checkIncomingConnectionVariables(response http.ResponseWriter, request *http.Request) int {
func (wn *WebsocketNetwork) checkIncomingConnectionVariables(response http.ResponseWriter, request *http.Request, remoteAddrForLogging string) int {
// check to see that the genesisID in the request URI is valid and matches the supported one.
pathVars := mux.Vars(request)
otherGenesisID, hasGenesisID := pathVars["genesisID"]
Expand All @@ -958,7 +958,7 @@ func (wn *WebsocketNetwork) checkIncomingConnectionVariables(response http.Respo
}

if wn.GenesisID != otherGenesisID {
wn.log.Warn(filterASCII(fmt.Sprintf("new peer %#v genesis mismatch, mine=%#v theirs=%#v, headers %#v", request.RemoteAddr, wn.GenesisID, otherGenesisID, request.Header)))
wn.log.Warn(filterASCII(fmt.Sprintf("new peer %#v genesis mismatch, mine=%#v theirs=%#v, headers %#v", remoteAddrForLogging, wn.GenesisID, otherGenesisID, request.Header)))
networkConnectionsDroppedTotal.Inc(map[string]string{"reason": "mismatching genesis-id"})
response.WriteHeader(http.StatusPreconditionFailed)
n, err := response.Write([]byte("mismatching genesis ID"))
Expand All @@ -973,7 +973,7 @@ func (wn *WebsocketNetwork) checkIncomingConnectionVariables(response http.Respo
// This is pretty harmless and some configurations of phonebooks or DNS records make this likely. Quietly filter it out.
var message string
// missing header.
wn.log.Warn(filterASCII(fmt.Sprintf("new peer %s did not include random ID header in request. mine=%s headers %#v", request.RemoteAddr, wn.RandomID, request.Header)))
wn.log.Warn(filterASCII(fmt.Sprintf("new peer %s did not include random ID header in request. mine=%s headers %#v", remoteAddrForLogging, wn.RandomID, request.Header)))
networkConnectionsDroppedTotal.Inc(map[string]string{"reason": "missing random ID header"})
message = fmt.Sprintf("Request was missing a %s header", NodeRandomHeader)
response.WriteHeader(http.StatusPreconditionFailed)
Expand All @@ -985,7 +985,7 @@ func (wn *WebsocketNetwork) checkIncomingConnectionVariables(response http.Respo
} else if otherRandom == wn.RandomID {
// This is pretty harmless and some configurations of phonebooks or DNS records make this likely. Quietly filter it out.
var message string
wn.log.Debugf("new peer %s has same node random id, am I talking to myself? %s", request.RemoteAddr, wn.RandomID)
wn.log.Debugf("new peer %s has same node random id, am I talking to myself? %s", remoteAddrForLogging, wn.RandomID)
networkConnectionsDroppedTotal.Inc(map[string]string{"reason": "matching random ID header"})
message = fmt.Sprintf("Request included matching %s=%s header", NodeRandomHeader, otherRandom)
response.WriteHeader(http.StatusLoopDetected)
Expand Down Expand Up @@ -1025,7 +1025,7 @@ func (wn *WebsocketNetwork) ServeHTTP(response http.ResponseWriter, request *htt

matchingVersion, otherVersion := wn.checkProtocolVersionMatch(request.Header)
if matchingVersion == "" {
wn.log.Info(filterASCII(fmt.Sprintf("new peer %s version mismatch, mine=%v theirs=%s, headers %#v", request.RemoteAddr, wn.supportedProtocolVersions, otherVersion, request.Header)))
wn.log.Info(filterASCII(fmt.Sprintf("new peer %s version mismatch, mine=%v theirs=%s, headers %#v", trackedRequest.remoteHost, wn.supportedProtocolVersions, otherVersion, request.Header)))
networkConnectionsDroppedTotal.Inc(map[string]string{"reason": "mismatching protocol version"})
response.WriteHeader(http.StatusPreconditionFailed)
message := fmt.Sprintf("Requested version %s not in %v mismatches server version", filterASCII(otherVersion), wn.supportedProtocolVersions)
Expand All @@ -1036,14 +1036,11 @@ func (wn *WebsocketNetwork) ServeHTTP(response http.ResponseWriter, request *htt
return
}

if wn.checkIncomingConnectionVariables(response, request) != http.StatusOK {
if wn.checkIncomingConnectionVariables(response, request, trackedRequest.remoteAddress()) != http.StatusOK {
// we've already logged and written all response(s).
return
}

// if UseXForwardedForAddressField is not empty, attempt to override the otherPublicAddr with the X Forwarded For origin
trackedRequest.otherPublicAddr = trackedRequest.remoteAddr

responseHeader := make(http.Header)
wn.setHeaders(responseHeader)
responseHeader.Set(ProtocolVersionHeader, matchingVersion)
Expand All @@ -1063,7 +1060,7 @@ func (wn *WebsocketNetwork) ServeHTTP(response http.ResponseWriter, request *htt
peerIDChallenge, peerID, err = wn.identityScheme.VerifyRequestAndAttachResponse(responseHeader, request.Header)
if err != nil {
networkPeerIdentityError.Inc(nil)
wn.log.With("err", err).With("remote", trackedRequest.otherPublicAddr).With("local", localAddr).Warnf("peer (%s) supplied an invalid identity challenge, abandoning peering", trackedRequest.otherPublicAddr)
wn.log.With("err", err).With("remote", trackedRequest.remoteAddress()).With("local", localAddr).Warnf("peer (%s) supplied an invalid identity challenge, abandoning peering", trackedRequest.remoteAddr)
return
}
}
Expand All @@ -1081,7 +1078,7 @@ func (wn *WebsocketNetwork) ServeHTTP(response http.ResponseWriter, request *htt
}

peer := &wsPeer{
wsPeerCore: makePeerCore(wn.ctx, wn, wn.log, wn.handler.readBuffer, trackedRequest.otherPublicAddr, wn.GetRoundTripper(), trackedRequest.remoteHost),
wsPeerCore: makePeerCore(wn.ctx, wn, wn.log, wn.handler.readBuffer, trackedRequest.remoteAddress(), wn.GetRoundTripper(), trackedRequest.remoteHost),
conn: wsPeerWebsocketConnImpl{conn},
outgoing: false,
InstanceName: trackedRequest.otherInstanceName,
Expand All @@ -1097,7 +1094,7 @@ func (wn *WebsocketNetwork) ServeHTTP(response http.ResponseWriter, request *htt
peer.TelemetryGUID = trackedRequest.otherTelemetryGUID
peer.init(wn.config, wn.outgoingMessagesBufferSize)
wn.addPeer(peer)
wn.log.With("event", "ConnectedIn").With("remote", trackedRequest.otherPublicAddr).With("local", localAddr).Infof("Accepted incoming connection from peer %s", trackedRequest.otherPublicAddr)
wn.log.With("event", "ConnectedIn").With("remote", trackedRequest.remoteAddress()).With("local", localAddr).Infof("Accepted incoming connection from peer %s", trackedRequest.remoteAddr)
wn.log.EventWithDetails(telemetryspec.Network, telemetryspec.ConnectPeerEvent,
telemetryspec.PeerEventDetails{
Address: trackedRequest.remoteHost,
Expand Down Expand Up @@ -2047,6 +2044,7 @@ func (wn *WebsocketNetwork) tryConnect(addr, gossipAddr string) {
}
}()
defer wn.wg.Done()

requestHeader := make(http.Header)
wn.setHeaders(requestHeader)
for _, supportedProtocolVersion := range wn.supportedProtocolVersions {
Expand Down

0 comments on commit 9229066

Please sign in to comment.