diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs index 58109b24fa2..6b5cd58ab3b 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs @@ -33,6 +33,7 @@ import Control.Monad.Class.MonadSTM.Strict import Control.Tracer (Tracer, traceWith, contramap) import Data.Foldable (traverse_) import Data.Functor (($>)) +import Data.Function (on) import Data.Maybe (maybeToList) import Data.Proxy (Proxy (..)) import Data.Typeable (Typeable) @@ -110,6 +111,60 @@ data ConnectionManagerArguments handlerTrace socket peerAddr handle handleError } +-- | 'MutableConnState', which supplies a unique identifier. +-- +-- TODO: We can get away without id, by tracking connections in +-- `TerminatingState` using a seprate priority search queue. +-- +data MutableConnState peerAddr handle handleError version m = MutableConnState { + -- | A unique identifier + -- + connStateId :: !Int + + , -- | Mutable state + -- + connVar :: !(StrictTVar m (ConnectionState peerAddr handle handleError + version m)) + } + + +instance Eq (MutableConnState peerAddr handle handleError version m) where + (==) = (==) `on` connStateId + + +-- | A supply of fresh id's. +-- +-- We use a fresh ids for 'MutableConnState'. +-- +newtype FreshIdSupply m = FreshIdSupply { getFreshId :: STM m Int } + + +-- | Create a 'FreshIdSupply' inside and 'STM' monad. +-- +newFreshIdSupply :: forall m. MonadSTM m + => Proxy m -> STM m (FreshIdSupply m) +newFreshIdSupply _ = do + (v :: StrictTVar m Int) <- newTVar 0 + let getFreshId :: STM m Int + getFreshId = do + c <- readTVar v + writeTVar v (succ c) + return c + return $ FreshIdSupply { getFreshId } + + +newMutableConnState :: MonadSTM m + => FreshIdSupply m + -> ConnectionState peerAddr handle handleError + version m + -> STM m (MutableConnState peerAddr handle handleError + version m) +newMutableConnState freshIdSupply connState = do + connStateId <- getFreshId freshIdSupply + connVar <- newTVar connState + return $! MutableConnState { connStateId, connVar } + + -- | 'ConnectionManager' state: for each peer we keep a 'ConnectionState' in -- a mutable variable, which reduce congestion on the 'TMVar' which keeps -- 'ConnectionManagerState'. @@ -119,7 +174,7 @@ data ConnectionManagerArguments handlerTrace socket peerAddr handle handleError -- @peerAddr@ and reuse the 'ConnectionState'. -- type ConnectionManagerState peerAddr handle handleError version m - = Map peerAddr (StrictTVar m (ConnectionState peerAddr handle handleError version m)) + = Map peerAddr (MutableConnState peerAddr handle handleError version m) -- | State of a connection. @@ -366,11 +421,17 @@ withConnectionManager ConnectionManagerArguments { } classifyHandleError k = do - (stateVar :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m)) + ((freshIdSupply, stateVar) + :: ( FreshIdSupply m + , StrictTMVar m (ConnectionManagerState peerAddr handle handleError + version m) + )) <- atomically $ do - v <- newTMVar Map.empty + v <- newTMVar Map.empty labelTMVar v "cm-state" - return v + freshIdSupply <- newFreshIdSupply (Proxy :: Proxy m) + return (freshIdSupply, v) + let connectionManager :: ConnectionManager muxMode socket peerAddr handle handleError m connectionManager = @@ -380,7 +441,7 @@ withConnectionManager ConnectionManagerArguments { (WithInitiatorMode OutboundConnectionManager { ocmRequestConnection = - requestOutboundConnectionImpl stateVar outboundHandler, + requestOutboundConnectionImpl freshIdSupply stateVar outboundHandler, ocmUnregisterConnection = unregisterOutboundConnectionImpl stateVar }) @@ -390,7 +451,7 @@ withConnectionManager ConnectionManagerArguments { (WithResponderMode InboundConnectionManager { icmIncludeConnection = - includeInboundConnectionImpl stateVar inboundHandler, + includeInboundConnectionImpl freshIdSupply stateVar inboundHandler, icmUnregisterConnection = unregisterInboundConnectionImpl stateVar, icmPromotedToWarmRemote = @@ -406,13 +467,13 @@ withConnectionManager ConnectionManagerArguments { (WithInitiatorResponderMode OutboundConnectionManager { ocmRequestConnection = - requestOutboundConnectionImpl stateVar outboundHandler, + requestOutboundConnectionImpl freshIdSupply stateVar outboundHandler, ocmUnregisterConnection = unregisterOutboundConnectionImpl stateVar } InboundConnectionManager { icmIncludeConnection = - includeInboundConnectionImpl stateVar inboundHandler, + includeInboundConnectionImpl freshIdSupply stateVar inboundHandler, icmUnregisterConnection = unregisterInboundConnectionImpl stateVar, icmPromotedToWarmRemote = @@ -428,7 +489,7 @@ withConnectionManager ConnectionManagerArguments { traceWith tracer TrShutdown state <- atomically $ readTMVar stateVar traverse_ - (\connVar -> do + (\MutableConnState { connVar } -> do -- cleanup handler for that thread will close socket associated -- with the thread. We put each connection in 'TerminatedState' to -- guarantee, that non of the connection threads will enter @@ -457,7 +518,7 @@ withConnectionManager ConnectionManagerArguments { DuplexState {} -> True TerminatingState {} -> False TerminatedState {} -> False) - <$> traverse readTVar state + <$> traverse (readTVar . connVar) state -- Start connection thread and run connection handler on it. @@ -494,7 +555,7 @@ withConnectionManager ConnectionManagerArguments { wConnVar <- uninterruptibleMask_ $ atomically $ do case Map.lookup peerAddr state of Nothing -> return Nowhere - Just connVar -> do + Just mcs@MutableConnState { connVar } -> do connState <- readTVar connVar case connState of ReservedOutboundState -> do @@ -519,7 +580,7 @@ withConnectionManager ConnectionManagerArguments { writeTVar connVar (TerminatedState Nothing) return $ There connState TerminatingState {} -> do - return $ Here connVar + return $ Here mcs TerminatedState {} -> return $ There connState case wConnVar of @@ -533,10 +594,10 @@ withConnectionManager ConnectionManagerArguments { return ( Map.delete peerAddr state , Left (Known connState) ) - Here connVar -> do + Here mutableConnStateAndTransition -> do close cmSnocket socket return ( state - , Right connVar + , Right mutableConnStateAndTransition ) case mConnVar of @@ -546,7 +607,7 @@ withConnectionManager ConnectionManagerArguments { { fromState = connState , toState = Unknown }) - Right connVar -> + Right (mcs@MutableConnState { connVar }) -> do traceWith tracer (TrConnectionTimeWait connId) when (cmTimeWaitTimeout > 0) $ unmask (threadDelay cmTimeWaitTimeout) @@ -558,7 +619,11 @@ withConnectionManager ConnectionManagerArguments { traceWith tracer (TrConnectionTimeWaitDone connId) trs <- atomically $ do mConnState <- readTMVar stateVar - >>= traverse readTVar . Map.lookup peerAddr + >>= traverse + ( readTVar + . (\MutableConnState { connVar = v } -> v) + ) + . Map.lookup peerAddr -- We can always write to `connVar`, since a new -- connection will use a new 'TVar', but we have to be -- careful when deleting it from 'ConnectionManagerState'. @@ -571,7 +636,7 @@ withConnectionManager ConnectionManagerArguments { . Map.updateLookupWithKey (\_ v -> -- only delete if it wasn't replaced - if eqTVar (Proxy :: Proxy m) connVar v + if mcs == v then Nothing else Just v ) @@ -618,20 +683,22 @@ withConnectionManager ConnectionManagerArguments { includeInboundConnectionImpl :: HasCallStack - => StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) + => FreshIdSupply m + -> StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) -> ConnectionHandlerFn handlerTrace socket peerAddr handle handleError version m -> socket -- ^ resource to include in the state -> peerAddr -- ^ remote address used as an identifier of the resource -> m (Connected peerAddr handle handleError) - includeInboundConnectionImpl stateVar + includeInboundConnectionImpl freshIdSupply + stateVar handler socket peerAddr = do let provenance = Inbound traceWith tracer (TrIncludeConnection provenance peerAddr) - (connVar, connId, connThread, reader) + (MutableConnState { connVar }, connId, connThread, reader) <- modifyTMVar stateVar $ \state -> do (reader, writer) <- newEmptyPromiseIO (connId, connThread) @@ -654,9 +721,10 @@ withConnectionManager ConnectionManagerArguments { let connState' = UnnegotiatedState provenance connId connThread (connVar, connState) <- atomically $ do - v <- newTVar connState' - labelTVar v ("conn-state-" ++ show connId) - connState <- traverse readTVar (Map.lookup peerAddr state) + v <- newMutableConnState freshIdSupply connState' + labelTVar (connVar v) ("conn-state-" ++ show connId) + connState <- traverse (readTVar . connVar) + (Map.lookup peerAddr state) return ( v , maybe Unknown Known connState ) @@ -773,7 +841,7 @@ withConnectionManager ConnectionManagerArguments { pure ( Nothing , Nothing , UnsupportedState UnknownConnectionSt ) - Just connVar -> do + Just MutableConnState { connVar } -> do connState <- readTVar connVar case connState of -- In any of the following two states unregistering is not @@ -864,30 +932,31 @@ withConnectionManager ConnectionManagerArguments { requestOutboundConnectionImpl :: HasCallStack - => StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) + => FreshIdSupply m + -> StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) -> ConnectionHandlerFn handlerTrace socket peerAddr handle handleError version m -> peerAddr -> m (Connected peerAddr handle handleError) - requestOutboundConnectionImpl stateVar handler peerAddr = do + requestOutboundConnectionImpl freshIdSupply stateVar handler peerAddr = do let provenance = Outbound traceWith tracer (TrIncludeConnection provenance peerAddr) - (trace, connVar, eHandleWedge) <- atomically $ do + (trace, mutableConnState@MutableConnState { connVar }, eHandleWedge) <- atomically $ do state <- readTMVar stateVar case Map.lookup peerAddr state of - Just connVar -> do + Just mutableConnState@MutableConnState { connVar } -> do connState <- readTVar connVar let st = abstractState (Known connState) case connState of ReservedOutboundState -> return ( Just (Right (TrConnectionExists provenance peerAddr st)) - , connVar + , mutableConnState , Left (withCallStack (ConnectionExists provenance peerAddr)) ) UnnegotiatedState Outbound _connId _connThread -> do return ( Just (Right (TrConnectionExists provenance peerAddr st)) - , connVar + , mutableConnState , Left (withCallStack (ConnectionExists provenance peerAddr)) ) @@ -897,27 +966,27 @@ withConnectionManager ConnectionManagerArguments { -- return 'There' to indicate that we need to block on -- the connection state. return ( Nothing - , connVar + , mutableConnState , Right (There connId) ) OutboundUniState {} -> do return ( Just (Right (TrConnectionExists provenance peerAddr st)) - , connVar + , mutableConnState , Left (withCallStack (ConnectionExists provenance peerAddr)) ) OutboundDupState {} -> do return ( Just (Right (TrConnectionExists provenance peerAddr st)) - , connVar + , mutableConnState , Left (withCallStack (ConnectionExists provenance peerAddr)) ) InboundIdleState connId _connThread _handle Unidirectional -> do return ( Just (Right (TrForbiddenConnection connId)) - , connVar + , mutableConnState , Left (withCallStack (ForbiddenConnection connId)) ) @@ -932,7 +1001,7 @@ withConnectionManager ConnectionManagerArguments { return ( Just (Left (TransitionTrace peerAddr (mkTransition connState connState'))) - , connVar + , mutableConnState , Right (Here (Connected connId dataFlow handle)) ) @@ -940,7 +1009,7 @@ withConnectionManager ConnectionManagerArguments { -- the remote side negotiated unidirectional connection, we -- cannot re-use it. return ( Just (Right (TrForbiddenConnection connId)) - , connVar + , mutableConnState , Left (withCallStack (ForbiddenConnection connId)) ) @@ -955,13 +1024,13 @@ withConnectionManager ConnectionManagerArguments { return ( Just (Left (TransitionTrace peerAddr (mkTransition connState connState'))) - , connVar + , mutableConnState , Right (Here (Connected connId dataFlow handle)) ) DuplexState _connId _connThread _handle -> return ( Just (Right (TrConnectionExists provenance peerAddr st)) - , connVar + , mutableConnState , Left (withCallStack (ConnectionExists provenance peerAddr)) ) @@ -979,28 +1048,34 @@ withConnectionManager ConnectionManagerArguments { return ( Just (Left (TransitionTrace peerAddr (mkTransition connState connState'))) - , connVar + , mutableConnState , Right Nowhere ) Nothing -> do let connState' = ReservedOutboundState - connVar <- newTVar connState' + (mutableConnState :: MutableConnState peerAddr handle handleError + version m) + <- newMutableConnState freshIdSupply connState' + -- TODO: label `connVar` using 'ConnectionId' + labelTVar (connVar mutableConnState) ("conn-state-" ++ show peerAddr) + -- record the @connVar@ in 'ConnectionManagerState' we can use -- 'swapTMVar' as we did not use 'takeTMVar' at the beginning of -- this transaction. Since we already 'readTMVar', it will not -- block. - mbConnState <- - swapTMVar stateVar - (Map.insert peerAddr connVar state) - >>= traverse readTVar . Map.lookup peerAddr + (mbConnState + :: Maybe (ConnectionState peerAddr handle handleError version m)) + <- swapTMVar stateVar + (Map.insert peerAddr mutableConnState state) + >>= traverse (readTVar . connVar) . Map.lookup peerAddr return ( Just (Left (TransitionTrace peerAddr Transition { fromState = maybe Unknown Known mbConnState, toState = Known connState' })) - , connVar + , mutableConnState , Right Nowhere ) @@ -1020,10 +1095,10 @@ withConnectionManager ConnectionManagerArguments { let connState' = TerminatedState Nothing writeTVar connVar connState' modifyTMVarPure_ stateVar $ - (Map.update (\connVar' -> - if eqTVar (Proxy :: Proxy m) connVar' connVar + (Map.update (\mutableConnState' -> + if mutableConnState' == mutableConnState then Nothing - else Just connVar') + else Just mutableConnState') peerAddr) return (mkTransition connState connState') traceWith trTracer (TransitionTrace peerAddr tr) @@ -1082,10 +1157,10 @@ withConnectionManager ConnectionManagerArguments { HandshakeProtocolViolation -> TerminatedState (Just handleError) return ( Map.update - (\connVar' -> - if eqTVar (Proxy :: Proxy m) connVar' connVar + (\mutableConnState' -> + if mutableConnState' == mutableConnState then Nothing - else Just connVar') + else Just mutableConnState') peerAddr state , Disconnected connId (Just handleError) @@ -1224,7 +1299,7 @@ withConnectionManager ConnectionManagerArguments { -- Calling 'unregisterOutboundConnection' is a no-op in this case. Nothing -> pure (DemoteToColdLocalNoop (Transition Unknown Unknown)) - Just connVar -> do + Just MutableConnState { connVar } -> do connState <- readTVar connVar case connState of -- In any of the following three states unregistering is not @@ -1310,7 +1385,7 @@ withConnectionManager ConnectionManagerArguments { -- This excludes connections in 'ReservedOutboundState', -- 'TerminatingState' and 'TerminatedState'. (choiseMap :: Map peerAddr (ConnectionType, Async m ())) - <- flip Map.traverseMaybeWithKey state $ \_peerAddr connVar' -> + <- flip Map.traverseMaybeWithKey state $ \_peerAddr MutableConnState { connVar = connVar' } -> (\cs -> -- this expression returns @Maybe (connType, connThread)@; -- 'traverseMaybeWithKey' collects all 'Just' cases. (,) <$> getConnType cs @@ -1374,7 +1449,7 @@ withConnectionManager ConnectionManagerArguments { mbConnVar <- Map.lookup peerAddr <$> readTMVar stateVar case mbConnVar of Nothing -> return (UnsupportedState UnknownConnectionSt) - Just connVar -> do + Just MutableConnState { connVar } -> do connState <- readTVar connVar case connState of ReservedOutboundState {} -> @@ -1430,7 +1505,7 @@ withConnectionManager ConnectionManagerArguments { mbConnVar <- Map.lookup peerAddr <$> readTMVar stateVar case mbConnVar of Nothing -> return (UnsupportedState UnknownConnectionSt) - Just connVar -> do + Just MutableConnState { connVar } -> do connState <- readTVar connVar case connState of ReservedOutboundState {} ->