diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs index 1e043b1d135..3e8dee3dada 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs @@ -94,7 +94,11 @@ data HandshakeArguments connectionId vNumber vData m application = HandshakeArgu :: VersionDataCodec CBOR.Term vNumber vData, -- | versioned application aggreed upon with the 'Handshake' protocol. - haVersions :: Versions vNumber vData application + haVersions :: Versions vNumber vData application, + + -- | accept version, first argument is our version data the second + -- argument is the remote version data. + haAcceptVersion :: vData -> vData -> Accept vData } @@ -111,18 +115,17 @@ runHandshakeClient ) => MuxBearer m -> connectionId - -> (vData -> vData -> Accept vData) -> HandshakeArguments connectionId vNumber vData m application -> m (Either (HandshakeException (HandshakeClientProtocolError vNumber)) (application, vNumber, vData)) runHandshakeClient bearer connectionId - acceptVersion HandshakeArguments { haHandshakeTracer, haHandshakeCodec, haVersionDataCodec, - haVersions + haVersions, + haAcceptVersion } = tryHandshake (fst <$> @@ -132,7 +135,7 @@ runHandshakeClient bearer byteLimitsHandshake timeLimitsHandshake (fromChannel (muxBearerAsChannel bearer handshakeProtocolNum InitiatorDir)) - (handshakeClientPeer haVersionDataCodec acceptVersion haVersions)) + (handshakeClientPeer haVersionDataCodec haAcceptVersion haVersions)) -- | Run server side of the 'Handshake' protocol. @@ -148,19 +151,18 @@ runHandshakeServer ) => MuxBearer m -> connectionId - -> (vData -> vData -> Accept vData) -> HandshakeArguments connectionId vNumber vData m application -> m (Either (HandshakeException (RefuseReason vNumber)) (application, vNumber, vData)) runHandshakeServer bearer connectionId - acceptVersion HandshakeArguments { haHandshakeTracer, haHandshakeCodec, haVersionDataCodec, - haVersions + haVersions, + haAcceptVersion } = tryHandshake (fst <$> @@ -170,4 +172,4 @@ runHandshakeServer bearer byteLimitsHandshake timeLimitsHandshake (fromChannel (muxBearerAsChannel bearer handshakeProtocolNum ResponderDir)) - (handshakeServerPeer haVersionDataCodec acceptVersion haVersions)) + (handshakeServerPeer haVersionDataCodec haAcceptVersion haVersions)) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs index c6735035a41..d0822d7a4cf 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs @@ -248,13 +248,13 @@ connectToNode' sn handshakeCodec versionDataCodec NetworkConnectTracers {nctMuxT runHandshakeClient (Snocket.toBearer sn sduHandshakeTimeout muxTracer sd) connectionId - acceptVersion -- TODO: push 'HandshakeArguments' up the call stack. HandshakeArguments { haHandshakeTracer = nctHandshakeTracer, haHandshakeCodec = handshakeCodec, haVersionDataCodec = versionDataCodec, - haVersions = versions + haVersions = versions, + haAcceptVersion = acceptVersion } ts_end <- getMonotonicTime case app_e of @@ -368,12 +368,12 @@ beginConnection sn muxTracer handshakeTracer handshakeCodec versionDataCodec acc runHandshakeServer (Snocket.toBearer sn sduHandshakeTimeout muxTracer' sd) connectionId - acceptVersion HandshakeArguments { haHandshakeTracer = handshakeTracer, haHandshakeCodec = handshakeCodec, haVersionDataCodec = versionDataCodec, - haVersions = versions + haVersions = versions, + haAcceptVersion = acceptVersion } case app_e of