diff --git a/ouroboros-network-framework/ouroboros-network-framework.cabal b/ouroboros-network-framework/ouroboros-network-framework.cabal index 57b5b6af9c4..afaaea4f8af 100644 --- a/ouroboros-network-framework/ouroboros-network-framework.cabal +++ b/ouroboros-network-framework/ouroboros-network-framework.cabal @@ -44,6 +44,7 @@ library Ouroboros.Network.ConnectionManager.Types Ouroboros.Network.ConnectionManager.Core Ouroboros.Network.ConnectionManager.ConnectionHandler + Ouroboros.Network.ConnectionManager.Server Ouroboros.Network.Server.ConnectionTable Ouroboros.Network.Server.Socket Ouroboros.Network.Server.RateLimiting @@ -108,6 +109,7 @@ test-suite ouroboros-network-framework-tests Network.TypedProtocol.ReqResp.Codec.CBOR Test.Network.TypedProtocol.PingPong.Codec Test.Network.TypedProtocol.ReqResp.Codec + Test.Ouroboros.Network.ConnectionManager.Server Test.Ouroboros.Network.Driver Test.Ouroboros.Network.Orphans Test.Ouroboros.Network.Socket @@ -130,6 +132,7 @@ test-suite ouroboros-network-framework-tests , tasty , tasty-quickcheck + , cardano-prelude , contra-tracer , io-sim diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Server.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Server.hs new file mode 100644 index 00000000000..d42d3337a0b --- /dev/null +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Server.hs @@ -0,0 +1,195 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +-- | Server implementation based on 'ConnectionManager' +-- +-- TODO: in the futures this should be moved to `Ouroboros.Network.Server`, but +-- to avoid confusion it will be kept here for now. +-- +module Ouroboros.Network.ConnectionManager.Server + ( ServerArguments (..) + , run + -- * Trace + , ServerTrace (..) + , AcceptConnectionsPolicyTrace (..) + + -- * Internals + , peekAlt + ) where + +import Control.Applicative (Alternative (..)) +import Control.Exception (SomeException) +import Control.Monad.Class.MonadAsync +import Control.Monad.Class.MonadSTM +import Control.Monad.Class.MonadThrow +import Control.Monad.Class.MonadTime +import Control.Monad.Class.MonadTimer +import Control.Tracer (Tracer, contramap, traceWith) +import Data.ByteString.Lazy (ByteString) +import Data.Void (Void) +import Data.Foldable (traverse_) +import Data.Functor (void) +import Data.Sequence.Strict (StrictSeq (..), (|>), (><)) +import qualified Data.Sequence.Strict as Seq + +import qualified Network.Mux as Mux + +import Ouroboros.Network.ConnectionManager.Types +import Ouroboros.Network.ConnectionManager.ConnectionHandler +import Ouroboros.Network.Mux +import Ouroboros.Network.Channel (fromChannel) +import Ouroboros.Network.Server.RateLimiting +import Ouroboros.Network.Snocket + + +data ServerArguments (muxMode :: MuxMode) socket peerAddr versionNumber versionDict bytes m = ServerArguments { + serverSocket :: socket, + serverSnocket :: Snocket m socket peerAddr, + serverTracer :: Tracer m (ServerTrace peerAddr), + serverConnectionLimits :: AcceptedConnectionsLimit, + serverConnectionManager :: MuxConnectionManager muxMode socket peerAddr + versionNumber bytes m + } + +run :: forall muxMode socket peerAddr versionNumber versionDict m. + ( MonadAsync m + , MonadCatch m + , MonadDelay m + , MonadTime m + , Mux.HasResponder muxMode ~ True + ) + => ServerArguments muxMode socket peerAddr versionNumber versionDict ByteString m + -> m Void +run ServerArguments { + serverSocket, + serverSnocket, + serverTracer, + serverConnectionLimits, + serverConnectionManager + } = + getLocalAddr serverSnocket serverSocket >>= \localAddr -> do + traceWith serverTracer (ServerStarted localAddr) + muxVars <- newTVarM Seq.Empty + (uncurry (<>)) <$> + (monitoring muxVars) + `concurrently` + (acceptLoop muxVars (accept serverSnocket serverSocket)) + `finally` + traceWith serverTracer (ServerStopped localAddr) + where + -- This is the tricky part of the `monitoring` thread. We want to return + -- the 'a' and the list of all other unresolved transations (otherwise we + -- would leaked memory). It is implemented in terms of 'Alternative' for + -- testing purposes. + peekSTM :: StrictSeq (STM m a) -> STM m (a, StrictSeq (STM m a)) + peekSTM = peekAlt + + monitoring :: TVar m + (StrictSeq + (STM m (MuxPromise muxMode verionNumber ByteString m))) + -> m Void + monitoring muxVars = do + muxPromise <- atomically $ do + muxs <- readTVar muxVars + (muxPromise, muxs') <- peekSTM muxs + writeTVar muxVars muxs' + pure muxPromise + case muxPromise of + MuxRunning mux ptcls _scheduleStopVar -> + traverse_ (runResponder mux) ptcls + _ -> pure () + monitoring muxVars + + + runResponder :: Mux.Mux muxMode m -> MiniProtocol muxMode ByteString m a b -> m () + runResponder mux MiniProtocol { + miniProtocolNum, + miniProtocolRun + } = + case miniProtocolRun of + ResponderProtocolOnly responder -> + void $ + Mux.runMiniProtocol + mux miniProtocolNum + Mux.ResponderDirectionOnly + Mux.StartEagerly + -- TODO: eliminate 'fromChannel' + (runMuxPeer responder . fromChannel) + InitiatorAndResponderProtocol _ responder -> + void $ + Mux.runMiniProtocol + mux miniProtocolNum + Mux.ResponderDirection + Mux.StartEagerly + (runMuxPeer responder . fromChannel) + + + acceptLoop :: TVar m + (StrictSeq + (STM m + (MuxPromise muxMode versionNumber ByteString m))) + -> Accept m SomeException peerAddr socket + -> m Void + acceptLoop muxVars acceptOne = do + runConnectionRateLimits + (ServerAcceptPolicyTrace `contramap` serverTracer) + (numberOfConnections serverConnectionManager) + serverConnectionLimits + result <- runAccept acceptOne + case result of + (AcceptException err, acceptNext) -> do + traceWith serverTracer (ServerAcceptError err) + acceptLoop muxVars acceptNext + (Accepted socket peerAddr, acceptNext) -> do + traceWith serverTracer (ServerAcceptConnection peerAddr) + !muxPromise <- + includeInboundConnection + serverConnectionManager + socket peerAddr + atomically $ modifyTVar muxVars (\as -> as |> muxPromise) + acceptLoop muxVars acceptNext + + +-- +-- Trace +-- + +data ServerTrace peerAddr + = ServerAcceptConnection peerAddr + | ServerAcceptError SomeException + | ServerAcceptPolicyTrace AcceptConnectionsPolicyTrace + | ServerStarted peerAddr + | ServerStopped peerAddr + deriving Show + +-- +-- Internals +-- + +-- | 'peekAlt' finds first non 'empty' element and returns it together with the +-- sequence of all the other ones (preserving their original order). Only the +-- returned non-empty element is dropped from the sequence. It is expressed +-- using 'Alternative' applicative functor, instead of `STM m` for +-- testing purposes. +-- +peekAlt :: Alternative m + => StrictSeq (m a) + -> m (a, StrictSeq (m a)) +peekAlt = go Seq.Empty + where + -- in the cons case we either can resolve 'stm', in which case we + -- return the value together with list of all other transactions, or + -- (`<|>`) we push it on the `acc` and recrurse. + go !acc (stm :<| stms) = + ((\a -> (a, acc >< stms)) <$> stm) + <|> + go (acc |> stm) stms + -- in the 'Empty' case, we just need to 'retry' the trasaction (hence we + -- use 'empty'). + go _acc Seq.Empty = empty diff --git a/ouroboros-network-framework/test/Main.hs b/ouroboros-network-framework/test/Main.hs index 2555ca90e6e..2a02c5a4b73 100644 --- a/ouroboros-network-framework/test/Main.hs +++ b/ouroboros-network-framework/test/Main.hs @@ -4,6 +4,7 @@ import Test.Tasty import qualified Test.Network.TypedProtocol.PingPong.Codec as PingPong import qualified Test.Network.TypedProtocol.ReqResp.Codec as ReqResp +import qualified Test.Ouroboros.Network.ConnectionManager.Server as Server import qualified Test.Ouroboros.Network.Driver as Driver import qualified Test.Ouroboros.Network.Socket as Socket import qualified Test.Ouroboros.Network.Subscription as Subscription @@ -18,6 +19,7 @@ tests = [ PingPong.tests , ReqResp.tests , Driver.tests + , Server.tests , Socket.tests , Subscription.tests , RateLimiting.tests diff --git a/ouroboros-network-framework/test/Test/Ouroboros/Network/ConnectionManager/Server.hs b/ouroboros-network-framework/test/Test/Ouroboros/Network/ConnectionManager/Server.hs new file mode 100644 index 00000000000..b8a4deab78d --- /dev/null +++ b/ouroboros-network-framework/test/Test/Ouroboros/Network/ConnectionManager/Server.hs @@ -0,0 +1,114 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module Test.Ouroboros.Network.ConnectionManager.Server where + +import Control.Applicative +import Control.Monad (MonadPlus, join) +import Data.Foldable (toList) +import Data.Sequence.Strict (StrictSeq) +import qualified Data.Sequence.Strict as Seq + +import Test.QuickCheck +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.QuickCheck (testProperty) + +import qualified Ouroboros.Network.ConnectionManager.Server as Server + + +tests :: TestTree +tests = + testGroup "Ouroboros.Network.ConnectionManager.Server" + [ testGroup "peekAlt" + [ testProperty "foldr (List)" (prop_peekAlt_foldr @[] @Int) + , testProperty "foldr (Maybe)" (prop_peekAlt_foldr @Maybe @Int) + , testProperty "sequence (Maybe)" (prop_peekAlt_sequence @Maybe @Int) + , testProperty "cons (Maybe)" (prop_peekAlt_cons @Maybe @Int) + ] + ] + + +-- +-- peekAlt properties +-- + +-- We are ulitmately interested in this properties for `STM` functor, but we +-- only test them for 'Maybe' monad. This is enough since there is an +-- isomrphism (it preserves 'Alternative' operations) in `Kleisli IO`: +-- +-- > toSTM :: Maybe a -> IO (STM m a) +-- > toSTM Nothing = pure retry +-- > toSTM (Just a) = pure (pure a) +-- +-- with an inverse: +-- +-- > fromSTM :: STM m a -> IO (Maybe a) +-- > fromSTM ma = atomically (ma `orElse` (pure Nothing)) + + +prop_peekAlt_foldr + :: forall m a. + ( Eq (m a) + , Show (m a) + , Alternative m ) + => [m a] -> Property +prop_peekAlt_foldr as = + (fst <$> Server.peekAlt (Seq.fromList as)) + === + (foldr (<|>) empty as) + + +-- | Recursively calling 'peekAlt' is like filtering non 'empty' elements and +-- 'sequence'. +-- +prop_peekAlt_sequence + :: forall m a. + ( Eq (m a) + , Eq (m [a]) + , Eq (m (a, StrictSeq (m a))) + , Show (m [a]) + , MonadPlus m ) + => [m a] -> Property +prop_peekAlt_sequence as = + peekAll [] (Seq.fromList as) + === + sequence (filter (/= empty) as) + where + -- recursievly 'peekAlt' and collect results + peekAll :: [a] -> StrictSeq (m a) -> m [a] + peekAll acc s = + case Server.peekAlt s of + res | res == empty -> pure (reverse acc) + | otherwise -> join $ (\(a, s') -> peekAll (a : acc) s') <$> res + + +-- | Calling `peekAlt` and then cominging the result with a cons ('<|'), should +-- put the first non 'empty' element in front. +-- +prop_peekAlt_cons + :: forall m a. + ( Eq (m a) + , Eq (m [m a]) + , Show (m [m a]) + , Alternative m ) + => [m a] -> Property +prop_peekAlt_cons as = + let x = Server.peekAlt (Seq.fromList as) + + mhead :: m a + mhead = fst <$> x + + mtail :: m (StrictSeq (m a)) + mtail = snd <$> x + + in ((toList . (mhead Seq.<|)) <$> mtail) + === + case span (empty ==) as of + -- if all 'as' entries where 'empty' + (_, []) -> empty + -- otherwise take the first element of `as'`, then list all the empty + -- elements from start of `as`, then the rest of `as'`. + (empties, (a : as')) -> pure (a : empties ++ as')