diff --git a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs index 2556e4c6a41..77fd460aa28 100644 --- a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs @@ -108,20 +108,27 @@ data Connection m = Connection -- | Opening state of a connection. This is used to detect simultaneous -- open. -- - , connState :: !OpenState + , connState :: !ConnectionState } -data OpenState - -- | Half opened connection is after calling `connect` but before the other - -- side picked it. Either using simultaneous open or normal open. - = HalfOpened +-- | Connection state as seen by the network environment. We borrow TCP state +-- names, but be aware that these states, unlike in TCP, are not local to the +-- service point. +-- +data ConnectionState + -- | SYN_SENT connection state: after calling `connect` but before the + -- other side accepted it: either as a simultaneous open or normal open. + -- + = SYN_SENT -- | This corresponds to established state of a tcp connection. - | Established + -- + | ESTABLISHED - -- | Half closed connection. - | HalfClosed + -- | Half opened connection. + -- + | FIN deriving (Eq, Show) @@ -168,7 +175,7 @@ mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } = return $ Connection channelLocal channelRemote (biSDUSize bearerInfo) - HalfOpened + SYN_SENT -- | Connection id independent of who provisioned the connection. 'NormalisedId' @@ -308,7 +315,7 @@ newNetworkState bearerInfoScript = atomically $ do data ResourceException addr = NotReleasedListeningSockets [addr] (Maybe SomeException) - | NotReleasedConnections (Map (NormalisedId addr) OpenState) + | NotReleasedConnections (Map (NormalisedId addr) ConnectionState) (Maybe SomeException) deriving (Show, Typeable) @@ -424,6 +431,7 @@ data FD_ m addr -- | 'FD_' for snockets in listening state. -- -- 'FDListening' is created by 'listen' + -- | FDListening !addr -- ^ listening address @@ -437,14 +445,16 @@ data FD_ m addr -- 'accept' is the consumer. -- | 'FD_' was passed to 'connect' call, if needed an ephemeral address was - -- assigned to it. + -- assigned to it. This corresponds to 'SYN_SENT' state. -- | FDConnecting !(ConnectionId addr) !(Connection m) -- | 'FD_' for snockets in connected state. -- - -- 'FDConnected' is created by either 'connect' or 'accept'. + -- 'FDConnected' is created by either 'connect' or 'accept'. It + -- corresponds to 'ESTABLISHED' state. + -- | FDConnected !(ConnectionId addr) -- ^ local and remote addresses @@ -517,7 +527,7 @@ data SnocketTrace m addr | STConnectTimeout TimeoutDetail | STBindError (FD_ m addr) addr IOError | STClosing SockType (Wedge (ConnectionId addr) [addr]) - | STClosed SockType (Maybe (Maybe OpenState)) + | STClosed SockType (Maybe (Maybe ConnectionState)) -- ^ TODO: Document meaning of 'Maybe (Maybe OpenState)' | STClosingQueue Bool | STClosedQueue Bool @@ -691,10 +701,12 @@ mkSnocket state tr = Snocket { getLocalAddr conMap <- readTVar (nsConnections state) case Map.lookup (normaliseId connId) conMap of - Just Connection { connState = Established } -> + Just Connection { connState = ESTABLISHED } -> throwSTM (connectedIOError fd_) - Just conn@Connection { connState = HalfOpened } -> do - let conn' = conn { connState = Established } + + -- simultaneous open + Just conn@Connection { connState = SYN_SENT } -> do + let conn' = conn { connState = ESTABLISHED } writeTVar fdVarLocal (FDConnecting connId conn') modifyTVar (nsConnections state) (Map.adjust (const conn') @@ -703,8 +715,10 @@ mkSnocket state tr = Snocket { getLocalAddr , connId , bearerInfo ) - Just Connection { connState = HalfClosed } -> + + Just Connection { connState = FIN } -> throwSTM (connectedIOError fd_) + Nothing -> do conn <- mkConnection tr bearerInfo connId writeTVar fdVarLocal (FDConnecting connId conn) @@ -765,10 +779,10 @@ mkSnocket state tr = Snocket { getLocalAddr writeTVar fdVarLocal fd_' mConn <- Map.lookup (normaliseId connId) <$> readTVar (nsConnections state) case mConn of - Just Connection { connState = Established } -> + Just Connection { connState = ESTABLISHED } -> -- successful simultaneous open return (Right (fd_', NormalOpen)) - Just Connection { connState = HalfOpened } -> do + Just Connection { connState = SYN_SENT } -> do writeTBQueue queue ChannelWithInfo { cwiAddress = localAddress connId @@ -777,8 +791,8 @@ mkSnocket state tr = Snocket { getLocalAddr , cwiChannelRemote = connChannelLocal } return (Right (fd_', NormalOpen)) - Just Connection { connState = HalfClosed } -> do - return (Left (connectIOError connId "connect error (half-closed)")) + Just Connection { connState = FIN } -> do + return (Left (connectIOError connId "connect error (FIN)")) Nothing -> return (Left (connectIOError connId "connect error")) @@ -812,7 +826,7 @@ mkSnocket state tr = Snocket { getLocalAddr $ "unknown connection: " ++ show (normaliseId connId) Just Connection { connState } -> - Just <$> check (connState == Established)) + Just <$> check (connState == ESTABLISHED)) ) `onException` atomically (modifyTVar (nsConnections state) @@ -947,25 +961,19 @@ mkSnocket state tr = Snocket { getLocalAddr (TestAddress addr)) accept FD { fdVar } = pure accept_ where - readTBQueueUntil :: (a -> STM m Bool) -> TBQueue m a -> STM m a - readTBQueueUntil p queue = do - a <- readTBQueue queue - shouldReturn <- p a - if shouldReturn - then return a - else readTBQueueUntil p queue - - isHalfOpened :: TestAddress addr - -> ChannelWithInfo m (TestAddress addr) - -> STM m Bool - isHalfOpened localAddress cwi = do + -- non-blocking; return 'True' if a connection is in 'SYN_SENT' state + -- or if it was removed from simulation state. + synSentOrUnknown :: TestAddress addr + -> ChannelWithInfo m (TestAddress addr) + -> STM m Bool + synSentOrUnknown localAddress cwi = do connMap <- readTVar (nsConnections state) let connId = ConnectionId localAddress (cwiAddress cwi) case Map.lookup (normaliseId connId) connMap of - Nothing -> return True - Just (Connection _ _ _ HalfOpened) -> return True - _ -> return False + Nothing -> return True + Just (Connection _ _ _ SYN_SENT) -> return True + _ -> return False accept_ = Accept $ \unmask -> do bracketOnError @@ -994,11 +1002,11 @@ mkSnocket state tr = Snocket { getLocalAddr FDListening localAddress queue -> do -- We should not accept nor fail the 'accept' call -- in the presence of a connection that is in - -- HalfOpened state. So we take from the TBQueue - -- until we have found one that is __not__ in HalfOpened + -- SYN_SENT state. So we take from the TBQueue + -- until we have found one that is __not__ in SYN_SENT -- state. cwi <- readTBQueueUntil - (isHalfOpened localAddress) + (synSentOrUnknown localAddress) queue let connId = ConnectionId localAddress (cwiAddress cwi) @@ -1036,7 +1044,7 @@ mkSnocket state tr = Snocket { getLocalAddr fdRemote <- atomically $ do modifyTVar (nsConnections state) - (Map.adjust (\s -> s { connState = Established }) + (Map.adjust (\s -> s { connState = ESTABLISHED }) (normaliseId connId)) FD <$> newTVar (FDConnected @@ -1045,7 +1053,7 @@ mkSnocket state tr = Snocket { getLocalAddr { connChannelLocal = channelLocal , connChannelRemote = channelRemote , connSDUSize = sduSize - , connState = Established + , connState = ESTABLISHED }) traceWith tr (WithAddr (Just (localAddress connId)) Nothing @@ -1118,10 +1126,10 @@ mkSnocket state tr = Snocket { getLocalAddr (Map.update (\conn@Connection { connState } -> case connState of - HalfClosed -> + FIN -> Nothing _ -> - Just conn { connState = HalfClosed }) + Just conn { connState = FIN }) (normaliseId connId))) (\(addr, _, _) -> modifyTVar (nsListeningFDs state) @@ -1195,3 +1203,17 @@ drainTBQueue q = do Nothing -> return [] Just a -> (a :) <$> drainTBQueue q + +-- | Return first element which satisfy the given predicate. +-- +readTBQueueUntil :: MonadSTMTx stm + => (a -> stm Bool) -- ^ a monadic predicate + -> TBQueue_ stm a -- ^ queue + -> stm a +readTBQueueUntil p q = do + a <- readTBQueue q + b <- p a + if b + then return a + else readTBQueueUntil p q +