From 70edb92d9d1cae0e363063129979c57cf2ac461a Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Thu, 30 Jan 2020 12:05:58 +0100 Subject: [PATCH] Improve support for monad stacks The main change in this commit is split in MonadSTM, which avoids the injectivity requirement, enabling GeneralizedNewtypeDeriving for MonadSTM. The remainder of the changes are related, and similarly intended to faciliate the derivation of monad stacks. --- .../src/Control/Monad/Class/MonadAsync.hs | 193 ++++++++++------ .../src/Control/Monad/Class/MonadST.hs | 5 +- .../src/Control/Monad/Class/MonadSTM.hs | 215 ++++++++++-------- .../Control/Monad/Class/MonadSTM/Strict.hs | 22 +- .../src/Control/Monad/Class/MonadTime.hs | 11 +- .../src/Control/Monad/Class/MonadTimer.hs | 51 +++-- io-sim/src/Control/Monad/IOSim.hs | 75 +++--- io-sim/test/Test/IOSim.hs | 19 +- io-sim/test/Test/STM.hs | 34 +-- .../src/Ouroboros/Consensus/NodeKernel.hs | 9 +- .../src/Ouroboros/Consensus/Util/EarlyExit.hs | 127 +++++------ .../src/Ouroboros/Consensus/Util/IOLike.hs | 1 + .../src/Ouroboros/Consensus/Util/STM.hs | 20 +- .../Ouroboros/Storage/ChainDB/Impl/Reader.hs | 6 +- .../Ouroboros/Storage/Util/ErrorHandling.hs | 3 +- .../test-consensus/Test/ThreadNet/Network.hs | 6 +- .../Test/Ouroboros/Storage/VolatileDB/Mock.hs | 4 +- .../Ouroboros/Network/BlockFetch/Examples.hs | 2 +- .../Ouroboros/Network/PeerSelection/Test.hs | 64 +++--- ouroboros-network/test/Test/Pipe.hs | 3 +- ouroboros-network/test/Test/Socket.hs | 8 +- ouroboros-network/test/Test/Subscription.hs | 27 +-- 22 files changed, 500 insertions(+), 405 deletions(-) diff --git a/io-sim-classes/src/Control/Monad/Class/MonadAsync.hs b/io-sim-classes/src/Control/Monad/Class/MonadAsync.hs index 91923970780..93b395a50b9 100644 --- a/io-sim-classes/src/Control/Monad/Class/MonadAsync.hs +++ b/io-sim-classes/src/Control/Monad/Class/MonadAsync.hs @@ -1,12 +1,14 @@ -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} module Control.Monad.Class.MonadAsync ( MonadAsync (..) + , MonadAsyncSTM (..) , AsyncCancelled(..) , ExceptionInLinkedThread(..) , link @@ -21,25 +23,81 @@ import Control.Monad.Class.MonadFork import Control.Monad.Class.MonadSTM import Control.Monad.Class.MonadThrow -import Control.Monad (void) +import Control.Concurrent.Async (AsyncCancelled (..)) +import qualified Control.Concurrent.Async as Async import Control.Exception (SomeException) import qualified Control.Exception as E -import qualified Control.Concurrent.Async as Async -import Control.Concurrent.Async (AsyncCancelled(..)) +import Control.Monad (void) +import Control.Monad.Reader +import qualified Control.Monad.STM as STM import Data.Proxy +class (Functor async, MonadSTMTx stm) => MonadAsyncSTM async stm where + {-# MINIMAL waitCatchSTM, pollSTM #-} + + waitSTM :: async a -> stm a + pollSTM :: async a -> stm (Maybe (Either SomeException a)) + waitCatchSTM :: async a -> stm (Either SomeException a) + + default waitSTM :: MonadThrow stm => async a -> stm a + waitSTM action = waitCatchSTM action >>= either throwM return + + waitAnySTM :: [async a] -> stm (async a, a) + waitAnyCatchSTM :: [async a] -> stm (async a, Either SomeException a) + waitEitherSTM :: async a -> async b -> stm (Either a b) + waitEitherSTM_ :: async a -> async b -> stm () + waitEitherCatchSTM :: async a -> async b + -> stm (Either (Either SomeException a) + (Either SomeException b)) + waitBothSTM :: async a -> async b -> stm (a, b) + + default waitAnySTM :: MonadThrow stm => [async a] -> stm (async a, a) + default waitEitherSTM :: MonadThrow stm => async a -> async b -> stm (Either a b) + default waitEitherSTM_ :: MonadThrow stm => async a -> async b -> stm () + default waitBothSTM :: MonadThrow stm => async a -> async b -> stm (a, b) + + waitAnySTM as = + foldr orElse retry $ + map (\a -> do r <- waitSTM a; return (a, r)) as + + waitAnyCatchSTM as = + foldr orElse retry $ + map (\a -> do r <- waitCatchSTM a; return (a, r)) as + + waitEitherSTM left right = + (Left <$> waitSTM left) + `orElse` + (Right <$> waitSTM right) + + waitEitherSTM_ left right = + (void $ waitSTM left) + `orElse` + (void $ waitSTM right) + + waitEitherCatchSTM left right = + (Left <$> waitCatchSTM left) + `orElse` + (Right <$> waitCatchSTM right) + + waitBothSTM left right = do + a <- waitSTM left + `orElse` + (waitSTM right >> retry) + b <- waitSTM right + return (a,b) + class ( MonadSTM m , MonadThread m - , Functor (Async m) + , MonadAsyncSTM (Async m) (STM m) ) => MonadAsync m where - {-# MINIMAL async, asyncThreadId, cancel, cancelWith, waitCatchSTM, pollSTM #-} + {-# MINIMAL async, asyncThreadId, cancel, cancelWith #-} -- | An asynchronous action type Async m :: * -> * async :: m a -> m (Async m a) - asyncThreadId :: proxy m -> Async m a -> ThreadId m + asyncThreadId :: Proxy m -> Async m a -> ThreadId m withAsync :: m a -> (Async m a -> m b) -> m b wait :: Async m a -> m a @@ -49,10 +107,6 @@ class ( MonadSTM m cancelWith :: Exception e => Async m a -> e -> m () uninterruptibleCancel :: Async m a -> m () - waitSTM :: Async m a -> STM m a - pollSTM :: Async m a -> STM m (Maybe (Either SomeException a)) - waitCatchSTM :: Async m a -> STM m (Either SomeException a) - waitAny :: [Async m a] -> m (Async m a, a) waitAnyCatch :: [Async m a] -> m (Async m a, Either SomeException a) waitAnyCancel :: [Async m a] -> m (Async m a, a) @@ -70,15 +124,6 @@ class ( MonadSTM m waitEither_ :: Async m a -> Async m b -> m () waitBoth :: Async m a -> Async m b -> m (a, b) - waitAnySTM :: [Async m a] -> STM m (Async m a, a) - waitAnyCatchSTM :: [Async m a] -> STM m (Async m a, Either SomeException a) - waitEitherSTM :: Async m a -> Async m b -> STM m (Either a b) - waitEitherSTM_ :: Async m a -> Async m b -> STM m () - waitEitherCatchSTM :: Async m a -> Async m b - -> STM m (Either (Either SomeException a) - (Either SomeException b)) - waitBothSTM :: Async m a -> Async m b -> STM m (a, b) - race :: m a -> m b -> m (Either a b) race_ :: m a -> m b -> m () concurrently :: m a -> m b -> m (a,b) @@ -87,7 +132,6 @@ class ( MonadSTM m default withAsync :: MonadMask m => m a -> (Async m a -> m b) -> m b default uninterruptibleCancel :: MonadMask m => Async m a -> m () - default waitSTM :: MonadThrow (STM m) => Async m a -> STM m a default waitAnyCancel :: MonadThrow m => [Async m a] -> m (Async m a, a) default waitAnyCatchCancel :: MonadThrow m => [Async m a] -> m (Async m a, Either SomeException a) @@ -97,12 +141,6 @@ class ( MonadSTM m -> m (Either (Either SomeException a) (Either SomeException b)) - default waitAnySTM :: MonadThrow (STM m) => [Async m a] -> STM m (Async m a, a) - default waitEitherSTM :: MonadThrow (STM m) => Async m a -> Async m b -> STM m (Either a b) - default waitEitherSTM_ :: MonadThrow (STM m) => Async m a -> Async m b -> STM m () - default waitBothSTM :: MonadThrow (STM m) => Async m a -> Async m b -> STM m (a, b) - - withAsync action inner = mask $ \restore -> do a <- async (restore action) restore (inner a) @@ -113,7 +151,6 @@ class ( MonadSTM m waitCatch = atomically . waitCatchSTM uninterruptibleCancel = uninterruptibleMask_ . cancel - waitSTM action = waitCatchSTM action >>= either throwM return waitAny = atomically . waitAnySTM waitAnyCatch = atomically . waitAnyCatchSTM @@ -134,36 +171,6 @@ class ( MonadSTM m waitEitherCatchCancel left right = waitEitherCatch left right `finally` (cancel left >> cancel right) - waitAnySTM as = - foldr orElse retry $ - map (\a -> do r <- waitSTM a; return (a, r)) as - - waitAnyCatchSTM as = - foldr orElse retry $ - map (\a -> do r <- waitCatchSTM a; return (a, r)) as - - waitEitherSTM left right = - (Left <$> waitSTM left) - `orElse` - (Right <$> waitSTM right) - - waitEitherSTM_ left right = - (void $ waitSTM left) - `orElse` - (void $ waitSTM right) - - waitEitherCatchSTM left right = - (Left <$> waitCatchSTM left) - `orElse` - (Right <$> waitCatchSTM right) - - waitBothSTM left right = do - a <- waitSTM left - `orElse` - (waitSTM right >> retry) - b <- waitSTM right - return (a,b) - race left right = withAsync left $ \a -> withAsync right $ \b -> waitEither a b @@ -180,6 +187,17 @@ class ( MonadSTM m -- Instance for IO uses the existing async library implementations -- +instance MonadAsyncSTM Async.Async STM.STM where + waitSTM = Async.waitSTM + pollSTM = Async.pollSTM + waitCatchSTM = Async.waitCatchSTM + waitAnySTM = Async.waitAnySTM + waitAnyCatchSTM = Async.waitAnyCatchSTM + waitEitherSTM = Async.waitEitherSTM + waitEitherSTM_ = Async.waitEitherSTM_ + waitEitherCatchSTM = Async.waitEitherCatchSTM + waitBothSTM = Async.waitBothSTM + instance MonadAsync IO where type Async IO = Async.Async @@ -195,10 +213,6 @@ instance MonadAsync IO where cancelWith = Async.cancelWith uninterruptibleCancel = Async.uninterruptibleCancel - waitSTM = Async.waitSTM - pollSTM = Async.pollSTM - waitCatchSTM = Async.waitCatchSTM - waitAny = Async.waitAny waitAnyCatch = Async.waitAnyCatch waitAnyCancel = Async.waitAnyCancel @@ -210,17 +224,46 @@ instance MonadAsync IO where waitEither_ = Async.waitEither_ waitBoth = Async.waitBoth - waitAnySTM = Async.waitAnySTM - waitAnyCatchSTM = Async.waitAnyCatchSTM - waitEitherSTM = Async.waitEitherSTM - waitEitherSTM_ = Async.waitEitherSTM_ - waitEitherCatchSTM = Async.waitEitherCatchSTM - waitBothSTM = Async.waitBothSTM - race = Async.race race_ = Async.race_ concurrently = Async.concurrently +-- +-- Lift to ReaderT +-- + +(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d) +(f .: g) x y = f (g x y) + +instance MonadAsync m => MonadAsync (ReaderT r m) where + type Async (ReaderT r m) = Async m + + asyncThreadId _ = asyncThreadId (Proxy @m) + + async (ReaderT ma) = ReaderT $ \r -> async (ma r) + withAsync (ReaderT ma) f = ReaderT $ \r -> withAsync (ma r) $ \a -> runReaderT (f a) r + + race (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race (ma r) (mb r) + race_ (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race_ (ma r) (mb r) + concurrently (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> concurrently (ma r) (mb r) + + wait = lift . wait + poll = lift . poll + waitCatch = lift . waitCatch + cancel = lift . cancel + uninterruptibleCancel = lift . uninterruptibleCancel + cancelWith = lift .: cancelWith + waitAny = lift . waitAny + waitAnyCatch = lift . waitAnyCatch + waitAnyCancel = lift . waitAnyCancel + waitAnyCatchCancel = lift . waitAnyCatchCancel + waitEither = lift .: waitEither + waitEitherCatch = lift .: waitEitherCatch + waitEitherCancel = lift .: waitEitherCancel + waitEitherCatchCancel = lift .: waitEitherCatchCancel + waitEither_ = lift .: waitEither_ + waitBoth = lift .: waitBoth + -- -- Linking -- @@ -275,7 +318,7 @@ linkToOnly tid shouldThrow a = do r <- waitCatch a case r of Left e | shouldThrow e -> throwTo tid (exceptionInLinkedThread e) - _otherwise -> return () + _otherwise -> return () where exceptionInLinkedThread :: SomeException -> ExceptionInLinkedThread exceptionInLinkedThread = diff --git a/io-sim-classes/src/Control/Monad/Class/MonadST.hs b/io-sim-classes/src/Control/Monad/Class/MonadST.hs index f2457915ee3..04319cf7808 100644 --- a/io-sim-classes/src/Control/Monad/Class/MonadST.hs +++ b/io-sim-classes/src/Control/Monad/Class/MonadST.hs @@ -1,7 +1,8 @@ {-# LANGUAGE RankNTypes #-} module Control.Monad.Class.MonadST where -import Control.Monad.ST (ST, stToIO) +import Control.Monad.Reader +import Control.Monad.ST (ST, stToIO) -- | This class is for abstracting over 'stToIO' which allows running 'ST' @@ -29,3 +30,5 @@ instance MonadST IO where instance MonadST (ST s) where withLiftST = \f -> f id +instance MonadST m => MonadST (ReaderT r m) where + withLiftST f = withLiftST $ \g -> f (lift . g) diff --git a/io-sim-classes/src/Control/Monad/Class/MonadSTM.hs b/io-sim-classes/src/Control/Monad/Class/MonadSTM.hs index 7fda21302ae..6bd900bf8d6 100644 --- a/io-sim-classes/src/Control/Monad/Class/MonadSTM.hs +++ b/io-sim-classes/src/Control/Monad/Class/MonadSTM.hs @@ -5,8 +5,13 @@ {-# LANGUAGE TypeFamilyDependencies #-} module Control.Monad.Class.MonadSTM ( MonadSTM (..) + , MonadSTMTx (..) , LazyTVar , LazyTMVar + , TVar + , TMVar + , TQueue + , TBQueue -- * Default 'TMVar' implementation , TMVarDefault (..) @@ -52,6 +57,7 @@ import qualified Control.Monad.STM as STM import Control.Applicative (Alternative (..)) import Control.Exception import Control.Monad (MonadPlus) +import Control.Monad.Reader import GHC.Stack import Numeric.Natural (Natural) @@ -61,113 +67,116 @@ import Numeric.Natural (Natural) type LazyTVar m = TVar m type LazyTMVar m = TMVar m -class ( Monad m - , Monad (STM m) - , Alternative (STM m) - , MonadPlus (STM m) - ) => MonadSTM m where +-- The STM primitives +class ( Monad stm + , Alternative stm + , MonadPlus stm + ) => MonadSTMTx stm where + type TVar_ stm :: * -> * - -- STM transactions - type STM m = (n :: * -> *) | n -> m - -- The STM primitives - type TVar m :: * -> * - - atomically :: HasCallStack => STM m a -> m a - newTVar :: a -> STM m (TVar m a) - readTVar :: TVar m a -> STM m a - writeTVar :: TVar m a -> a -> STM m () - retry :: STM m a - orElse :: STM m a -> STM m a -> STM m a - - -- Helpful derived functions with default implementations - newTVarM :: a -> m (TVar m a) - newTVarM = atomically . newTVar + newTVar :: a -> stm (TVar_ stm a) + readTVar :: TVar_ stm a -> stm a + writeTVar :: TVar_ stm a -> a -> stm () + retry :: stm a + orElse :: stm a -> stm a -> stm a - modifyTVar :: TVar m a -> (a -> a) -> STM m () + modifyTVar :: TVar_ stm a -> (a -> a) -> stm () modifyTVar v f = readTVar v >>= writeTVar v . f - modifyTVar' :: TVar m a -> (a -> a) -> STM m () + modifyTVar' :: TVar_ stm a -> (a -> a) -> stm () modifyTVar' v f = readTVar v >>= \x -> writeTVar v $! f x - check :: Bool -> STM m () + check :: Bool -> stm () check True = return () check _ = retry -- Additional derived STM APIs - type TMVar m :: * -> * - newTMVar :: a -> STM m (TMVar m a) - newTMVarM :: a -> m (TMVar m a) - newEmptyTMVar :: STM m (TMVar m a) - newEmptyTMVarM :: m (TMVar m a) - takeTMVar :: TMVar m a -> STM m a - tryTakeTMVar :: TMVar m a -> STM m (Maybe a) - putTMVar :: TMVar m a -> a -> STM m () - tryPutTMVar :: TMVar m a -> a -> STM m Bool - readTMVar :: TMVar m a -> STM m a - tryReadTMVar :: TMVar m a -> STM m (Maybe a) - swapTMVar :: TMVar m a -> a -> STM m a - isEmptyTMVar :: TMVar m a -> STM m Bool - - type TQueue m :: * -> * - newTQueue :: STM m (TQueue m a) - readTQueue :: TQueue m a -> STM m a - tryReadTQueue :: TQueue m a -> STM m (Maybe a) - writeTQueue :: TQueue m a -> a -> STM m () - isEmptyTQueue :: TQueue m a -> STM m Bool - - type TBQueue m :: * -> * - newTBQueue :: Natural -> STM m (TBQueue m a) - readTBQueue :: TBQueue m a -> STM m a - tryReadTBQueue :: TBQueue m a -> STM m (Maybe a) - writeTBQueue :: TBQueue m a -> a -> STM m () - isEmptyTBQueue :: TBQueue m a -> STM m Bool - isFullTBQueue :: TBQueue m a -> STM m Bool + type TMVar_ stm :: * -> * + newTMVar :: a -> stm (TMVar_ stm a) + newEmptyTMVar :: stm (TMVar_ stm a) + takeTMVar :: TMVar_ stm a -> stm a + tryTakeTMVar :: TMVar_ stm a -> stm (Maybe a) + putTMVar :: TMVar_ stm a -> a -> stm () + tryPutTMVar :: TMVar_ stm a -> a -> stm Bool + readTMVar :: TMVar_ stm a -> stm a + tryReadTMVar :: TMVar_ stm a -> stm (Maybe a) + swapTMVar :: TMVar_ stm a -> a -> stm a + isEmptyTMVar :: TMVar_ stm a -> stm Bool + + type TQueue_ stm :: * -> * + newTQueue :: stm (TQueue_ stm a) + readTQueue :: TQueue_ stm a -> stm a + tryReadTQueue :: TQueue_ stm a -> stm (Maybe a) + writeTQueue :: TQueue_ stm a -> a -> stm () + isEmptyTQueue :: TQueue_ stm a -> stm Bool + + type TBQueue_ stm :: * -> * + newTBQueue :: Natural -> stm (TBQueue_ stm a) + readTBQueue :: TBQueue_ stm a -> stm a + tryReadTBQueue :: TBQueue_ stm a -> stm (Maybe a) + writeTBQueue :: TBQueue_ stm a -> a -> stm () + isEmptyTBQueue :: TBQueue_ stm a -> stm Bool + isFullTBQueue :: TBQueue_ stm a -> stm Bool + +type TVar m = TVar_ (STM m) +type TMVar m = TMVar_ (STM m) +type TQueue m = TQueue_ (STM m) +type TBQueue m = TBQueue_ (STM m) + +class (Monad m, MonadSTMTx (STM m)) => MonadSTM m where + -- STM transactions + type STM m :: * -> * + + atomically :: HasCallStack => STM m a -> m a + + -- Helpful derived functions with default implementations + + newTVarM :: a -> m (TVar m a) + newTMVarM :: a -> m (TMVar m a) + newEmptyTMVarM :: m (TMVar m a) + newTVarM = atomically . newTVar + newTMVarM = atomically . newTMVar + newEmptyTMVarM = atomically newEmptyTMVar -- -- Instance for IO uses the existing STM library implementations -- -instance MonadSTM IO where - type STM IO = STM.STM - type TVar IO = STM.TVar - - atomically = wrapBlockedIndefinitely . STM.atomically - newTVar = STM.newTVar - readTVar = STM.readTVar - writeTVar = STM.writeTVar - retry = STM.retry - orElse = STM.orElse - - newTVarM = STM.newTVarIO - modifyTVar = STM.modifyTVar - modifyTVar' = STM.modifyTVar' - check = STM.check - - type TMVar IO = STM.TMVar - - newTMVar = STM.newTMVar - newTMVarM = STM.newTMVarIO - newEmptyTMVar = STM.newEmptyTMVar - newEmptyTMVarM = STM.newEmptyTMVarIO - takeTMVar = STM.takeTMVar - tryTakeTMVar = STM.tryTakeTMVar - putTMVar = STM.putTMVar - tryPutTMVar = STM.tryPutTMVar - readTMVar = STM.readTMVar - tryReadTMVar = STM.tryReadTMVar - swapTMVar = STM.swapTMVar - isEmptyTMVar = STM.isEmptyTMVar - - type TQueue IO = STM.TQueue - - newTQueue = STM.newTQueue - readTQueue = STM.readTQueue - tryReadTQueue = STM.tryReadTQueue - writeTQueue = STM.writeTQueue - isEmptyTQueue = STM.isEmptyTQueue - - type TBQueue IO = STM.TBQueue +instance MonadSTMTx STM.STM where + type TVar_ STM.STM = STM.TVar + type TMVar_ STM.STM = STM.TMVar + type TQueue_ STM.STM = STM.TQueue + type TBQueue_ STM.STM = STM.TBQueue + + newTVar = STM.newTVar + readTVar = STM.readTVar + writeTVar = STM.writeTVar + retry = STM.retry + orElse = STM.orElse + modifyTVar = STM.modifyTVar + modifyTVar' = STM.modifyTVar' + check = STM.check + newTMVar = STM.newTMVar + newEmptyTMVar = STM.newEmptyTMVar + takeTMVar = STM.takeTMVar + tryTakeTMVar = STM.tryTakeTMVar + putTMVar = STM.putTMVar + tryPutTMVar = STM.tryPutTMVar + readTMVar = STM.readTMVar + tryReadTMVar = STM.tryReadTMVar + swapTMVar = STM.swapTMVar + isEmptyTMVar = STM.isEmptyTMVar + newTQueue = STM.newTQueue + readTQueue = STM.readTQueue + tryReadTQueue = STM.tryReadTQueue + writeTQueue = STM.writeTQueue + isEmptyTQueue = STM.isEmptyTQueue + readTBQueue = STM.readTBQueue + tryReadTBQueue = STM.tryReadTBQueue + writeTBQueue = STM.writeTBQueue + isEmptyTBQueue = STM.isEmptyTBQueue + isFullTBQueue = STM.isFullTBQueue #if MIN_VERSION_stm(2,5,0) newTBQueue = STM.newTBQueue @@ -175,12 +184,15 @@ instance MonadSTM IO where -- STM prior to 2.5.0 takes an Int newTBQueue = STM.newTBQueue . fromEnum #endif - readTBQueue = STM.readTBQueue - tryReadTBQueue = STM.tryReadTBQueue - writeTBQueue = STM.writeTBQueue - isEmptyTBQueue = STM.isEmptyTBQueue - isFullTBQueue = STM.isFullTBQueue +instance MonadSTM IO where + type STM IO = STM.STM + + atomically = wrapBlockedIndefinitely . STM.atomically + + newTVarM = STM.newTVarIO + newTMVarM = STM.newTMVarIO + newEmptyTMVarM = STM.newEmptyTMVarIO -- | Wrapper around 'BlockedIndefinitelyOnSTM' that stores a call stack data BlockedIndefinitely = BlockedIndefinitely { @@ -198,6 +210,16 @@ instance Exception BlockedIndefinitely where wrapBlockedIndefinitely :: HasCallStack => IO a -> IO a wrapBlockedIndefinitely = handle (throwIO . BlockedIndefinitely callStack) +-- +-- Lift to monad transformers +-- + +instance MonadSTM m => MonadSTM (ReaderT r m) where + type STM (ReaderT r m) = STM m + atomically = lift . atomically + newTVarM = lift . newTVarM + newTMVarM = lift . newTMVarM + newEmptyTMVarM = lift newEmptyTMVarM -- -- Default TMVar implementation in terms of TVars (used by sim) @@ -277,7 +299,6 @@ isEmptyTMVarDefault (TMVar t) = do Nothing -> return True Just _ -> return False - -- -- Default TQueue implementation in terms of TVars (used by sim) -- diff --git a/io-sim-classes/src/Control/Monad/Class/MonadSTM/Strict.hs b/io-sim-classes/src/Control/Monad/Class/MonadSTM/Strict.hs index e212ecde5c7..531b9f41d92 100644 --- a/io-sim-classes/src/Control/Monad/Class/MonadSTM/Strict.hs +++ b/io-sim-classes/src/Control/Monad/Class/MonadSTM/Strict.hs @@ -2,12 +2,14 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE TypeFamilies #-} module Control.Monad.Class.MonadSTM.Strict ( module X , LazyTVar , LazyTMVar -- * 'StrictTVar' , StrictTVar + , castStrictTVar , toLazyTVar , newTVar , newTVarM @@ -18,6 +20,7 @@ module Control.Monad.Class.MonadSTM.Strict , updateTVar -- * 'StrictTMVar' , StrictTMVar + , castStrictTMVar , newTMVar , newTMVarM , newEmptyTMVar @@ -34,12 +37,11 @@ module Control.Monad.Class.MonadSTM.Strict , checkInvariant ) where -import Control.Monad.Class.MonadSTM as X hiding - (TVar, TMVar, LazyTVar, LazyTMVar, - isEmptyTMVar, modifyTVar, newEmptyTMVar, newEmptyTMVarM, - newTMVar, newTMVarM, newTVar, newTVarM, putTMVar, - readTMVar, readTVar, swapTMVar, takeTMVar, tryPutTMVar, - tryReadTMVar, tryTakeTMVar, writeTVar) +import Control.Monad.Class.MonadSTM as X hiding (LazyTMVar, LazyTVar, + TMVar, TVar, isEmptyTMVar, modifyTVar, newEmptyTMVar, + newEmptyTMVarM, newTMVar, newTMVarM, newTVar, newTVarM, + putTMVar, readTMVar, readTVar, swapTMVar, takeTMVar, + tryPutTMVar, tryReadTMVar, tryTakeTMVar, writeTVar) import qualified Control.Monad.Class.MonadSTM as Lazy import GHC.Stack @@ -60,6 +62,10 @@ data StrictTVar m a = StrictTVar , tvar :: !(LazyTVar m a) } +castStrictTVar :: LazyTVar m ~ LazyTVar n + => StrictTVar m a -> StrictTVar n a +castStrictTVar StrictTVar{invariant, tvar} = StrictTVar{invariant, tvar} + -- | Get the underlying @TVar@ -- -- Since we obviously cannot guarantee that updates to this 'LazyTVar' will be @@ -110,6 +116,10 @@ updateTVar v f = do -- to very hard to debug bugs where code is blocked indefinitely. newtype StrictTMVar m a = StrictTMVar (LazyTMVar m a) +castStrictTMVar :: LazyTMVar m ~ LazyTMVar n + => StrictTMVar m a -> StrictTMVar n a +castStrictTMVar (StrictTMVar var) = StrictTMVar var + newTMVar :: MonadSTM m => a -> STM m (StrictTMVar m a) newTMVar !a = StrictTMVar <$> Lazy.newTMVar a diff --git a/io-sim-classes/src/Control/Monad/Class/MonadTime.hs b/io-sim-classes/src/Control/Monad/Class/MonadTime.hs index 0d23461ea75..dc491d1ef82 100644 --- a/io-sim-classes/src/Control/Monad/Class/MonadTime.hs +++ b/io-sim-classes/src/Control/Monad/Class/MonadTime.hs @@ -8,10 +8,10 @@ module Control.Monad.Class.MonadTime ( , UTCTime ) where -import Data.Word (Word64) +import Control.Monad.Reader import Data.Time.Clock (DiffTime, UTCTime) import qualified Data.Time.Clock as Time - +import Data.Word (Word64) -- | A point in time in a monotonic clock. -- @@ -59,3 +59,10 @@ instance MonadTime IO where foreign import ccall unsafe "getMonotonicNSec" getMonotonicNSec :: IO Word64 +-- +-- Instance for ReaderT +-- + +instance MonadTime m => MonadTime (ReaderT r m) where + getMonotonicTime = lift getMonotonicTime + getCurrentTime = lift getCurrentTime diff --git a/io-sim-classes/src/Control/Monad/Class/MonadTimer.hs b/io-sim-classes/src/Control/Monad/Class/MonadTimer.hs index b147db00ee9..b599f116175 100644 --- a/io-sim-classes/src/Control/Monad/Class/MonadTimer.hs +++ b/io-sim-classes/src/Control/Monad/Class/MonadTimer.hs @@ -1,18 +1,20 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeFamilies #-} module Control.Monad.Class.MonadTimer ( - MonadTimer(..) + MonadDelay(..) + , MonadTimer(..) , TimeoutState(..) ) where import qualified Control.Concurrent as IO import qualified Control.Concurrent.STM.TVar as STM -import qualified Control.Monad.STM as STM import Control.Exception (assert) +import Control.Monad.Reader +import qualified Control.Monad.STM as STM import Data.Functor (void) import Data.Time.Clock (DiffTime, diffTimeToPicoseconds) @@ -28,10 +30,15 @@ import Control.Monad.Class.MonadSTM import qualified System.Timeout as IO - data TimeoutState = TimeoutPending | TimeoutFired | TimeoutCancelled -class MonadSTM m => MonadTimer m where +class Monad m => MonadDelay m where + threadDelay :: DiffTime -> m () + + default threadDelay :: MonadTimer m => DiffTime -> m () + threadDelay d = void . atomically . awaitTimeout =<< newTimeout d + +class (MonadSTM m, MonadDelay m) => MonadTimer m where data Timeout m :: * -- | Create a new timeout which will fire at the given time duration in @@ -85,9 +92,6 @@ class MonadSTM m => MonadTimer m where TimeoutFired -> return True TimeoutCancelled -> return False - threadDelay :: DiffTime -> m () - threadDelay d = void . atomically . awaitTimeout =<< newTimeout d - registerDelay :: DiffTime -> m (TVar m Bool) default registerDelay :: MonadFork m => DiffTime -> m (TVar m Bool) @@ -104,6 +108,8 @@ class MonadSTM m => MonadTimer m where -- Instances for IO -- +instance MonadDelay IO where + threadDelay d = IO.threadDelay (diffTimeToMicrosecondsAsInt d) #if defined(__GLASGOW_HASKELL__) && !defined(mingw32_HOST_OS) && !defined(__GHCJS__) instance MonadTimer IO where @@ -168,8 +174,6 @@ instance MonadTimer IO where when (not fired) $ STM.writeTVar cancelvar True #endif - threadDelay d = IO.threadDelay (diffTimeToMicrosecondsAsInt d) - registerDelay = STM.registerDelay . diffTimeToMicrosecondsAsInt timeout = IO.timeout . diffTimeToMicrosecondsAsInt @@ -183,3 +187,22 @@ diffTimeToMicrosecondsAsInt d = -- systems means 2^31 usec, which is only ~35 minutes. assert (usec <= fromIntegral (maxBound :: Int)) $ fromIntegral usec + +-- +-- Lift to ReaderT +-- + +instance MonadDelay m => MonadDelay (ReaderT r m) where + threadDelay = lift . threadDelay + +instance (MonadTimer m, MonadFork m) => MonadTimer (ReaderT r m) where + newtype Timeout (ReaderT r m) = WrapTimeoutReader { + unwrapTimeoutReader :: Timeout m + } + + newTimeout d = lift $ WrapTimeoutReader <$> newTimeout d + updateTimeout t = lift . updateTimeout (unwrapTimeoutReader t) + cancelTimeout t = lift $ cancelTimeout (unwrapTimeoutReader t) + + timeout d ma = ReaderT $ timeout d . runReaderT ma + readTimeout t = readTimeout $ unwrapTimeoutReader t diff --git a/io-sim/src/Control/Monad/IOSim.hs b/io-sim/src/Control/Monad/IOSim.hs index 767f9c5ef9d..c0362a21f0e 100644 --- a/io-sim/src/Control/Monad/IOSim.hs +++ b/io-sim/src/Control/Monad/IOSim.hs @@ -40,27 +40,25 @@ module Control.Monad.IOSim ( import Prelude hiding (read) +import Data.Dynamic (Dynamic, fromDynamic, toDyn) +import Data.Foldable (traverse_) import Data.Functor (void) -import Data.OrdPSQ (OrdPSQ) -import qualified Data.OrdPSQ as PSQ import qualified Data.List as List -import Data.Foldable (traverse_) -import qualified Data.Map.Strict as Map import Data.Map.Strict (Map) -import qualified Data.Set as Set +import qualified Data.Map.Strict as Map +import Data.OrdPSQ (OrdPSQ) +import qualified Data.OrdPSQ as PSQ import Data.Set (Set) +import qualified Data.Set as Set +import Data.Time (DiffTime, NominalDiffTime, UTCTime (..), addUTCTime, + diffUTCTime, fromGregorian) import Data.Typeable (Typeable) -import Data.Dynamic (Dynamic, toDyn, fromDynamic) -import Data.Time - ( DiffTime, NominalDiffTime, UTCTime(..) - , diffUTCTime, addUTCTime, fromGregorian ) - -import Control.Applicative (Applicative(..), Alternative(..)) -import Control.Monad (join, MonadPlus, mapM_) -import Control.Exception - ( Exception(..), SomeException - , ErrorCall(..), throw, assert - , asyncExceptionToException, asyncExceptionFromException ) + +import Control.Applicative (Alternative (..), Applicative (..)) +import Control.Exception (ErrorCall (..), Exception (..), + SomeException, assert, asyncExceptionFromException, + asyncExceptionToException, throw) +import Control.Monad (MonadPlus, join, mapM_) import qualified System.IO.Error as IO.Error (userError) import Control.Monad (when) @@ -68,18 +66,18 @@ import Control.Monad.ST.Lazy import qualified Control.Monad.ST.Strict as StrictST import Data.STRef.Lazy -import Control.Monad.Fail as MonadFail import qualified Control.Monad.Catch as Exceptions +import Control.Monad.Fail as MonadFail +import Control.Monad.Class.MonadAsync hiding (Async) +import qualified Control.Monad.Class.MonadAsync as MonadAsync import Control.Monad.Class.MonadFork hiding (ThreadId) import qualified Control.Monad.Class.MonadFork as MonadFork -import Control.Monad.Class.MonadThrow as MonadThrow import Control.Monad.Class.MonadSay import Control.Monad.Class.MonadST import Control.Monad.Class.MonadSTM hiding (STM, TVar) import qualified Control.Monad.Class.MonadSTM as MonadSTM -import Control.Monad.Class.MonadAsync hiding (Async) -import qualified Control.Monad.Class.MonadAsync as MonadAsync +import Control.Monad.Class.MonadThrow as MonadThrow import Control.Monad.Class.MonadTime import Control.Monad.Class.MonadTimer @@ -298,14 +296,11 @@ instance MonadFork (SimM s) where forkWithUnmask f = fork (f unblock) throwTo tid e = SimM $ \k -> ThrowTo (toException e) tid (k ()) -instance MonadSTM (SimM s) where - type STM (SimM s) = STM s - type TVar (SimM s) = TVar s - type TMVar (SimM s) = TMVarDefault (SimM s) - type TQueue (SimM s) = TQueueDefault (SimM s) - type TBQueue (SimM s) = TBQueueDefault (SimM s) - - atomically action = SimM $ \k -> Atomically action k +instance MonadSTMTx (STM s) where + type TVar_ (STM s) = TVar s + type TMVar_ (STM s) = TMVarDefault (SimM s) + type TQueue_ (STM s) = TQueueDefault (SimM s) + type TBQueue_ (STM s) = TBQueueDefault (SimM s) newTVar x = STM $ \k -> NewTVar x k readTVar tvar = STM $ \k -> ReadTVar tvar k @@ -314,9 +309,7 @@ instance MonadSTM (SimM s) where orElse a b = STM $ \k -> OrElse (runSTM a) (runSTM b) k newTMVar = newTMVarDefault - newTMVarM = newTMVarMDefault newEmptyTMVar = newEmptyTMVarDefault - newEmptyTMVarM = newEmptyTMVarMDefault takeTMVar = takeTMVarDefault tryTakeTMVar = tryTakeTMVarDefault putTMVar = putTMVarDefault @@ -339,6 +332,14 @@ instance MonadSTM (SimM s) where isEmptyTBQueue = isEmptyTBQueueDefault isFullTBQueue = isFullTBQueueDefault +instance MonadSTM (SimM s) where + type STM (SimM s) = STM s + + atomically action = SimM $ \k -> Atomically action k + + newTMVarM = newTMVarMDefault + newEmptyTMVarM = newEmptyTMVarMDefault + data Async s a = Async !ThreadId (STM s (Either SomeException a)) instance Eq (Async s a) where @@ -350,6 +351,10 @@ instance Ord (Async s a) where instance Functor (Async s) where fmap f (Async tid a) = Async tid (fmap f <$> a) +instance MonadAsyncSTM (Async s) (STM s) where + waitCatchSTM (Async _ w) = w + pollSTM (Async _ w) = (Just <$> w) `orElse` return Nothing + instance MonadAsync (SimM s) where type Async (SimM s) = Async s @@ -364,9 +369,6 @@ instance MonadAsync (SimM s) where cancel a@(Async tid _) = throwTo tid AsyncCancelled <* waitCatch a cancelWith a@(Async tid _) e = throwTo tid e <* waitCatch a - waitCatchSTM (Async _ w) = w - pollSTM (Async _ w) = (Just <$> w) `orElse` return Nothing - instance MonadST (SimM s) where withLiftST f = f liftST @@ -391,6 +393,9 @@ setCurrentTime t = SimM $ \k -> SetWallTime t (k ()) unshareClock :: SimM s () unshareClock = SimM $ \k -> UnshareClock (k ()) +instance MonadDelay (SimM s) where + -- Use default in terms of MonadTimer + instance MonadTimer (SimM s) where data Timeout (SimM s) = Timeout !(TVar s TimeoutState) !TimeoutId @@ -1159,7 +1164,7 @@ data TVar s a = TVar { -- | The identifier of this var. -- - tvarId :: !TVarId, + tvarId :: !TVarId, -- | The var's current value -- @@ -1200,7 +1205,7 @@ data SomeTVar s where data StmStack s b a where -- | Executing in the context of a top level 'atomically'. AtomicallyFrame :: StmStack s a a - + -- | Executing in the context of the /left/ hand side of an 'orElse' OrElseLeftFrame :: StmA s a -- orElse right alternative -> (a -> StmA s b) -- subsequent continuation diff --git a/io-sim/test/Test/IOSim.hs b/io-sim/test/Test/IOSim.hs index 7d84ace7b19..527d837de65 100644 --- a/io-sim/test/Test/IOSim.hs +++ b/io-sim/test/Test/IOSim.hs @@ -1,6 +1,6 @@ {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} module Test.IOSim ( tests @@ -14,16 +14,15 @@ import Data.Graph import Data.List (sort) import Data.Time.Clock (DiffTime, picosecondsToDiffTime) +import Control.Exception (ArithException (..)) import Control.Monad -import Control.Exception - ( ArithException(..) ) -import System.IO.Error (IOError, isUserError, ioeGetErrorString) +import System.IO.Error (IOError, ioeGetErrorString, isUserError) import Control.Monad.Class.MonadFork +import Control.Monad.Class.MonadSay import Control.Monad.Class.MonadSTM.Strict import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadTimer -import Control.Monad.Class.MonadSay import Control.Monad.IOSim import Test.STM @@ -104,7 +103,7 @@ prop_stm_graph_sim g = prop_stm_graph :: (MonadFork m, MonadSTM m) => TestThreadGraph -> m () prop_stm_graph (TestThreadGraph g) = do vars <- listArray (bounds g) <$> - sequence [ atomically (newTVar False) | _ <- vertices g ] + sequence [ newTVarM False | _ <- vertices g ] forM_ (vertices g) $ \v -> void $ fork $ do -- read all the inputs and wait for them to become true @@ -219,7 +218,7 @@ test_timers xs = experiment :: Probe m (DiffTime, Int) -> m () experiment p = do tvars <- forM (zip xs [0..]) $ \(t, idx) -> do - v <- atomically $ newTVar False + v <- newTVarM False void $ fork $ threadDelay t >> do probeOutput p (t, idx) atomically $ writeTVar v True @@ -263,7 +262,7 @@ test_fork_order = \(Positive n) -> isValid n <$> withProbe (experiment n) experiment :: Int -> Probe m Int -> m () experiment 0 _ = return () experiment n p = do - v <- atomically $ newTVar False + v <- newTVarM False void $ fork $ do probeOutput p n @@ -295,7 +294,7 @@ test_threadId_order = \(Positive n) -> do where experiment :: m (ThreadId m) experiment = do - v <- atomically $ newTVar False + v <- newTVarM False tid <- fork $ atomically $ writeTVar v True @@ -326,7 +325,7 @@ test_wakeup_order :: ( MonadFork m ) => m Property test_wakeup_order = do - v <- atomically $ newTVar False + v <- newTVarM False wakupOrder <- withProbe $ \p -> do sequence_ diff --git a/io-sim/test/Test/STM.hs b/io-sim/test/Test/STM.hs index a097a3aa20c..5502febedcb 100644 --- a/io-sim/test/Test/STM.hs +++ b/io-sim/test/Test/STM.hs @@ -1,15 +1,15 @@ -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# OPTIONS_GHC -Wno-unticked-promoted-constructors #-} @@ -22,16 +22,16 @@ -- module Test.STM where -import Prelude hiding (exp) -import Data.Type.Equality -import Data.Maybe (fromMaybe, maybeToList) -import qualified Data.Map.Strict as Map import Data.Map.Strict (Map) -import qualified Data.Set as Set +import qualified Data.Map.Strict as Map +import Data.Maybe (fromMaybe, maybeToList) import Data.Set (Set) +import qualified Data.Set as Set +import Data.Type.Equality +import Prelude hiding (exp) -import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadSTM as STM +import Control.Monad.Class.MonadThrow import Test.QuickCheck @@ -490,12 +490,12 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x) snapshotExecValue (ExecValVar v _) = fmap ImmValVar (snapshotExecValue =<< readTVar v) -execAtomically :: (MonadSTM m, MonadThrow (STM m), MonadCatch m) +execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m) => Term t -> m TxResult execAtomically t = toTxResult <$> try (atomically action') where - action = snapshotExecValue =<< execTerm mempty t + action = snapshotExecValue =<< execTerm (mempty :: ExecEnv m) t action' = fmap Just action `orElse` return Nothing -- We want to observe if the transaction would block. If we trust the STM @@ -534,7 +534,7 @@ data GenEnv = GenEnv { envNames :: TyTrie NameId, -- | For managing the fresh name supply - envNextName :: NameId + envNextName :: NameId } data TyTrie a = diff --git a/ouroboros-consensus/src/Ouroboros/Consensus/NodeKernel.hs b/ouroboros-consensus/src/Ouroboros/Consensus/NodeKernel.hs index 5dc5458839e..804bc22dc6c 100644 --- a/ouroboros-consensus/src/Ouroboros/Consensus/NodeKernel.hs +++ b/ouroboros-consensus/src/Ouroboros/Consensus/NodeKernel.hs @@ -494,9 +494,12 @@ forkBlockProduction maxBlockSizeOverride IS{..} BlockProduction{..} = noOverride = nodeMaxBlockSize ledger - blockEncOverhead runProtocol :: StrictTVar m PRNG -> ProtocolM blk m a -> STM m a - runProtocol varDRG = simOuroborosStateT varState - $ simChaChaT varDRG - $ id + runProtocol varDRG = runSim sim + where + sim :: Sim (NodeStateT (BlockProtocol blk) (ChaChaT (STM m))) m + sim = simOuroborosStateT varState + $ simChaChaT varDRG + $ simId -- | State of the pseudo-random number generator newtype PRNG = PRNG ChaChaDRG diff --git a/ouroboros-consensus/src/Ouroboros/Consensus/Util/EarlyExit.hs b/ouroboros-consensus/src/Ouroboros/Consensus/Util/EarlyExit.hs index dda1b50c452..abd36d33f6e 100644 --- a/ouroboros-consensus/src/Ouroboros/Consensus/Util/EarlyExit.hs +++ b/ouroboros-consensus/src/Ouroboros/Consensus/Util/EarlyExit.hs @@ -1,5 +1,6 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -75,74 +76,49 @@ instance (forall a'. NoUnexpectedThunks (m a')) whnfNoUnexpectedThunks ctxt = whnfNoUnexpectedThunks ctxt . withEarlyExit showTypeOf _p = "WithEarlyExit " ++ showTypeOf (Proxy @(m a)) -{------------------------------------------------------------------------------- - Special wrapper for STM - - This is required because MonadSTM requires STM to be injective. --------------------------------------------------------------------------------} - -newtype WrapSTM m a = Wrap { unwrap :: WithEarlyExit (STM m) a } - -unwrapSTM :: WrapSTM m a -> STM m (Maybe a) -unwrapSTM = withEarlyExit . unwrap - -wrapSTM :: STM m (Maybe a) -> WrapSTM m a -wrapSTM = Wrap . earlyExit - -wrapSTM' :: MonadSTM m => STM m a -> WrapSTM m a -wrapSTM' = wrapSTM . fmap Just - -deriving instance MonadSTM m => Functor (WrapSTM m) -deriving instance MonadSTM m => Applicative (WrapSTM m) -deriving instance MonadSTM m => Monad (WrapSTM m) -deriving instance MonadSTM m => Alternative (WrapSTM m) -deriving instance MonadSTM m => MonadPlus (WrapSTM m) - --- These two piggy-back on the instances for WithEarlyExit, below -deriving instance (MonadSTM m, MonadCatch (STM m)) => MonadThrow (WrapSTM m) -deriving instance (MonadSTM m, MonadCatch (STM m)) => MonadCatch (WrapSTM m) - {------------------------------------------------------------------------------- Instances for io-classes -------------------------------------------------------------------------------} +instance MonadSTMTx stm => MonadSTMTx (WithEarlyExit stm) where + type TVar_ (WithEarlyExit stm) = TVar_ stm + type TMVar_ (WithEarlyExit stm) = TMVar_ stm + type TQueue_ (WithEarlyExit stm) = TQueue_ stm + type TBQueue_ (WithEarlyExit stm) = TBQueue_ stm + + newTVar = lift . newTVar + readTVar = lift . readTVar + writeTVar = lift .: writeTVar + retry = lift retry + orElse = (earlyExit .: orElse) `on` withEarlyExit + newTMVar = lift . newTMVar + newEmptyTMVar = lift newEmptyTMVar + takeTMVar = lift . takeTMVar + tryTakeTMVar = lift . tryTakeTMVar + putTMVar = lift .: putTMVar + tryPutTMVar = lift .: tryPutTMVar + readTMVar = lift . readTMVar + tryReadTMVar = lift . tryReadTMVar + swapTMVar = lift .: swapTMVar + isEmptyTMVar = lift . isEmptyTMVar + newTQueue = lift newTQueue + readTQueue = lift . readTQueue + tryReadTQueue = lift . tryReadTQueue + writeTQueue = lift .: writeTQueue + isEmptyTQueue = lift . isEmptyTQueue + newTBQueue = lift . newTBQueue + readTBQueue = lift . readTBQueue + tryReadTBQueue = lift . tryReadTBQueue + writeTBQueue = lift .: writeTBQueue + isEmptyTBQueue = lift . isEmptyTBQueue + isFullTBQueue = lift . isFullTBQueue + instance MonadSTM m => MonadSTM (WithEarlyExit m) where - type STM (WithEarlyExit m) = WrapSTM m -- == WithEarlyExit (STM m) - type TVar (WithEarlyExit m) = TVar m - type TMVar (WithEarlyExit m) = TMVar m - type TQueue (WithEarlyExit m) = TQueue m - type TBQueue (WithEarlyExit m) = TBQueue m - - atomically = earlyExit . atomically . unwrapSTM - - newTVar = wrapSTM' . newTVar - readTVar = wrapSTM' . readTVar - writeTVar = wrapSTM' .: writeTVar - retry = wrapSTM' retry - orElse = (wrapSTM .: orElse) `on` unwrapSTM - newTMVar = wrapSTM' . newTMVar - newTMVarM = lift . newTMVarM - newEmptyTMVar = wrapSTM' newEmptyTMVar - newEmptyTMVarM = lift newEmptyTMVarM - takeTMVar = wrapSTM' . takeTMVar - tryTakeTMVar = wrapSTM' . tryTakeTMVar - putTMVar = wrapSTM' .: putTMVar - tryPutTMVar = wrapSTM' .: tryPutTMVar - readTMVar = wrapSTM' . readTMVar - tryReadTMVar = wrapSTM' . tryReadTMVar - swapTMVar = wrapSTM' .: swapTMVar - isEmptyTMVar = wrapSTM' . isEmptyTMVar - newTQueue = wrapSTM' newTQueue - readTQueue = wrapSTM' . readTQueue - tryReadTQueue = wrapSTM' . tryReadTQueue - writeTQueue = wrapSTM' .: writeTQueue - isEmptyTQueue = wrapSTM' . isEmptyTQueue - newTBQueue = wrapSTM' . newTBQueue - readTBQueue = wrapSTM' . readTBQueue - tryReadTBQueue = wrapSTM' . tryReadTBQueue - writeTBQueue = wrapSTM' .: writeTBQueue - isEmptyTBQueue = wrapSTM' . isEmptyTBQueue - isFullTBQueue = wrapSTM' . isFullTBQueue + type STM (WithEarlyExit m) = WithEarlyExit (STM m) + + atomically = earlyExit . atomically . withEarlyExit + newTMVarM = lift . newTMVarM + newEmptyTMVarM = lift newEmptyTMVarM instance MonadCatch m => MonadThrow (WithEarlyExit m) where throwM = lift . throwM @@ -188,18 +164,19 @@ instance MonadThread m => MonadThread (WithEarlyExit m) where myThreadId = lift myThreadId labelThread = lift .: labelThread -instance ( MonadMask m - , MonadAsync m - , MonadCatch (STM m) - ) => MonadAsync (WithEarlyExit m) where +instance (MonadAsyncSTM async stm, MonadCatch stm) + => MonadAsyncSTM (WithEarlyExit async) (WithEarlyExit stm) where + waitCatchSTM a = earlyExit (commute <$> waitCatchSTM (withEarlyExit a)) + pollSTM a = earlyExit (fmap commute <$> pollSTM (withEarlyExit a)) + +instance (MonadMask m, MonadAsync m, MonadCatch (STM m)) + => MonadAsync (WithEarlyExit m) where type Async (WithEarlyExit m) = WithEarlyExit (Async m) async = lift . (fmap earlyExit . async) . withEarlyExit asyncThreadId _p = asyncThreadId (Proxy @(WithEarlyExit m)) cancel a = lift $ cancel (withEarlyExit a) cancelWith a = lift . cancelWith (withEarlyExit a) - waitCatchSTM a = wrapSTM (commute <$> waitCatchSTM (withEarlyExit a)) - pollSTM a = wrapSTM (fmap commute <$> pollSTM (withEarlyExit a)) commute :: Either SomeException (Maybe a) -> Maybe (Either SomeException a) commute (Left e) = Just (Left e) @@ -227,14 +204,16 @@ instance MonadTime m => MonadTime (WithEarlyExit m) where getMonotonicTime = lift getMonotonicTime getCurrentTime = lift getCurrentTime +instance MonadDelay m => MonadDelay (WithEarlyExit m) where + threadDelay = lift . threadDelay + instance (MonadTimer m, MonadFork m) => MonadTimer (WithEarlyExit m) where newtype Timeout (WithEarlyExit m) = WrapTimeout { unwrapTimeout :: Timeout m } - threadDelay = lift . threadDelay - newTimeout d = lift $ WrapTimeout <$> newTimeout d - readTimeout t = wrapSTM' $ readTimeout (unwrapTimeout t) - updateTimeout t = lift . updateTimeout (unwrapTimeout t) - cancelTimeout t = lift $ cancelTimeout (unwrapTimeout t) + newTimeout d = lift $ WrapTimeout <$> newTimeout d + readTimeout t = lift $ readTimeout (unwrapTimeout t) + updateTimeout t = lift . updateTimeout (unwrapTimeout t) + cancelTimeout t = lift $ cancelTimeout (unwrapTimeout t) timeout d = earlyExit . timeout d . withEarlyExit {------------------------------------------------------------------------------- diff --git a/ouroboros-consensus/src/Ouroboros/Consensus/Util/IOLike.hs b/ouroboros-consensus/src/Ouroboros/Consensus/Util/IOLike.hs index c848c24a60f..e65f21739ae 100644 --- a/ouroboros-consensus/src/Ouroboros/Consensus/Util/IOLike.hs +++ b/ouroboros-consensus/src/Ouroboros/Consensus/Util/IOLike.hs @@ -31,6 +31,7 @@ module Ouroboros.Consensus.Util.IOLike ( , diffTime -- *** MonadDelay , MonadTimer(..) + , MonadDelay(..) -- *** Cardano prelude , NoUnexpectedThunks(..) ) where diff --git a/ouroboros-consensus/src/Ouroboros/Consensus/Util/STM.hs b/ouroboros-consensus/src/Ouroboros/Consensus/Util/STM.hs index 1fc7beb82de..5eec07ba592 100644 --- a/ouroboros-consensus/src/Ouroboros/Consensus/Util/STM.hs +++ b/ouroboros-consensus/src/Ouroboros/Consensus/Util/STM.hs @@ -17,7 +17,7 @@ module Ouroboros.Consensus.Util.STM ( , Fingerprint (..) , WithFingerprint (..) -- * Simulate various monad stacks in STM - , Sim + , Sim(..) , simId , simStateT , simOuroborosStateT @@ -41,8 +41,8 @@ import Ouroboros.Consensus.Util.ResourceRegistry -------------------------------------------------------------------------------} -- | Wait until the TVar changed -blockUntilChanged :: forall m a b. (IOLike m, Eq b) - => (a -> b) -> b -> STM m a -> STM m (a, b) +blockUntilChanged :: forall stm a b. (MonadSTMTx stm, Eq b) + => (a -> b) -> b -> stm a -> stm (a, b) blockUntilChanged f b getA = do a <- getA let b' = f a @@ -97,14 +97,14 @@ runWhenJust registry getMaybeA action = void $ forkLinkedThread registry $ action =<< atomically (blockUntilJust getMaybeA) -blockUntilJust :: IOLike m => STM m (Maybe a) -> STM m a +blockUntilJust :: MonadSTMTx stm => stm (Maybe a) -> stm a blockUntilJust getMaybeA = do ma <- getMaybeA case ma of Nothing -> retry Just a -> return a -blockUntilAllJust :: IOLike m => [STM m (Maybe a)] -> STM m [a] +blockUntilAllJust :: MonadSTMTx stm => [stm (Maybe a)] -> stm [a] blockUntilAllJust = mapM blockUntilJust -- | Simple type that can be used to indicate something in a @TVar@ is @@ -124,13 +124,13 @@ data WithFingerprint a = WithFingerprint Simulate monad stacks -------------------------------------------------------------------------------} -type Sim n m = forall a. n a -> STM m a +newtype Sim n m = Sim { runSim :: forall a. n a -> STM m a } simId :: Sim (STM m) m -simId = id +simId = Sim id simStateT :: IOLike m => StrictTVar m st -> Sim n m -> Sim (StateT st n) m -simStateT stVar k (StateT f) = do +simStateT stVar (Sim k) = Sim $ \(StateT f) -> do st <- readTVar stVar (a, st') <- k (f st) writeTVar stVar st' @@ -140,7 +140,7 @@ simOuroborosStateT :: IOLike m => StrictTVar m s -> Sim n m -> Sim (NodeStateT_ s n) m -simOuroborosStateT stVar k n = do +simOuroborosStateT stVar (Sim k) = Sim $ \n -> do st <- readTVar stVar (a, st') <- k (runNodeStateT n st) writeTVar stVar st' @@ -150,7 +150,7 @@ simChaChaT :: (IOLike m, Coercible a ChaChaDRG) => StrictTVar m a -> Sim n m -> Sim (ChaChaT n) m -simChaChaT stVar k n = do +simChaChaT stVar (Sim k) = Sim $ \n -> do st <- readTVar stVar (a, st') <- k (runChaChaT n (coerce st)) writeTVar stVar (coerce st') diff --git a/ouroboros-consensus/src/Ouroboros/Storage/ChainDB/Impl/Reader.hs b/ouroboros-consensus/src/Ouroboros/Storage/ChainDB/Impl/Reader.hs index 96184d2fda1..1e8008337ba 100644 --- a/ouroboros-consensus/src/Ouroboros/Storage/ChainDB/Impl/Reader.hs +++ b/ouroboros-consensus/src/Ouroboros/Storage/ChainDB/Impl/Reader.hs @@ -378,14 +378,14 @@ instructionHelper registry varReader blockComponent encodeHeader fromMaybeSTM CD -- | 'readerInstruction' for when the reader is in the 'ReaderInMem' state. instructionSTM - :: forall m blk. (IOLike m, HasHeader (Header blk)) + :: forall stm blk. (MonadSTMTx stm, HasHeader (Header blk)) => ReaderRollState blk -- ^ The current 'ReaderRollState' of the reader -> AnchoredFragment (Header blk) -- ^ The current chain fragment - -> (ReaderRollState blk -> STM m ()) + -> (ReaderRollState blk -> stm ()) -- ^ How to save the updated 'ReaderRollState' - -> STM m (Maybe (ChainUpdate blk (Header blk))) + -> stm (Maybe (ChainUpdate blk (Header blk))) instructionSTM rollState curChain saveRollState = assert (invariant curChain) $ case rollState of RollForwardFrom pt -> diff --git a/ouroboros-consensus/src/Ouroboros/Storage/Util/ErrorHandling.hs b/ouroboros-consensus/src/Ouroboros/Storage/Util/ErrorHandling.hs index e4944ef916d..b598e1d1665 100644 --- a/ouroboros-consensus/src/Ouroboros/Storage/Util/ErrorHandling.hs +++ b/ouroboros-consensus/src/Ouroboros/Storage/Util/ErrorHandling.hs @@ -49,7 +49,6 @@ import Data.Void import Cardano.Prelude (NoUnexpectedThunks (..), OnlyCheckIsWHNF (..)) -import Control.Monad.Class.MonadSTM import Control.Monad.Class.MonadThrow (MonadCatch) import qualified Control.Monad.Class.MonadThrow as C @@ -185,7 +184,7 @@ data ThrowCantCatch e m = ThrowCantCatch { throwCantCatch :: ErrorHandling e m -> ThrowCantCatch e m throwCantCatch ErrorHandling{..} = ThrowCantCatch throwError -throwSTM :: (C.MonadThrow (STM m), Exception e) => ThrowCantCatch e (STM m) +throwSTM :: (C.MonadThrow m, Exception e) => ThrowCantCatch e m throwSTM = ThrowCantCatch $ C.throwM {------------------------------------------------------------------------------- diff --git a/ouroboros-consensus/test-consensus/Test/ThreadNet/Network.hs b/ouroboros-consensus/test-consensus/Test/ThreadNet/Network.hs index daced878828..4c6b35a8ce5 100644 --- a/ouroboros-consensus/test-consensus/Test/ThreadNet/Network.hs +++ b/ouroboros-consensus/test-consensus/Test/ThreadNet/Network.hs @@ -450,7 +450,8 @@ runThreadNetwork ThreadNetworkArgs varDRG <- uncheckedNewTVarM =<< produceDRG txs <- atomically $ do ledger <- ledgerState <$> getExtLedger - simChaChaT varDRG id $ testGenTxs numCoreNodes curSlotNo cfg ledger + runSim (simChaChaT varDRG simId) $ + testGenTxs numCoreNodes curSlotNo cfg ledger void $ addTxs mempool txs forkEbbProducer :: HasCallStack @@ -581,7 +582,8 @@ runThreadNetwork ThreadNetworkArgs let blockProduction :: BlockProduction m blk blockProduction = BlockProduction { produceBlock = nodeForgeBlock pInfoConfig - , produceDRG = atomically $ simChaChaT varRNG id $ drgNew + , produceDRG = atomically $ + runSim (simChaChaT varRNG simId) drgNew } let NodeInfo diff --git a/ouroboros-consensus/test-storage/Test/Ouroboros/Storage/VolatileDB/Mock.hs b/ouroboros-consensus/test-storage/Test/Ouroboros/Storage/VolatileDB/Mock.hs index 8a18890db28..f666ae83496 100644 --- a/ouroboros-consensus/test-storage/Test/Ouroboros/Storage/VolatileDB/Mock.hs +++ b/ouroboros-consensus/test-storage/Test/Ouroboros/Storage/VolatileDB/Mock.hs @@ -7,7 +7,7 @@ import Control.Monad.State (StateT) import Ouroboros.Consensus.Util ((.:)) import Ouroboros.Consensus.Util.IOLike -import Ouroboros.Consensus.Util.STM (simStateT) +import Ouroboros.Consensus.Util.STM import Ouroboros.Storage.Common (castBlockComponent) import Ouroboros.Storage.Util.ErrorHandling (ThrowCantCatch) @@ -52,4 +52,4 @@ openDBMock err maxNumPerFile = do wrapModel :: StrictTVar m (DBModel blockId) -> StateT (DBModel blockId) (STM m) a -> STM m a - wrapModel dbVar = simStateT dbVar $ id + wrapModel dbVar = runSim (simStateT dbVar $ simId) diff --git a/ouroboros-network/test/Ouroboros/Network/BlockFetch/Examples.hs b/ouroboros-network/test/Ouroboros/Network/BlockFetch/Examples.hs index 9b9ca48b8c0..e1558c8ff36 100644 --- a/ouroboros-network/test/Ouroboros/Network/BlockFetch/Examples.hs +++ b/ouroboros-network/test/Ouroboros/Network/BlockFetch/Examples.hs @@ -311,7 +311,7 @@ mkTestFetchedBlockHeap :: (MonadSTM m, Ord (Point block)) => [Point block] -> m (TestFetchedBlockHeap m block) mkTestFetchedBlockHeap points = do - v <- atomically (newTVar (Set.fromList points)) + v <- newTVarM (Set.fromList points) return TestFetchedBlockHeap { getTestFetchedBlocks = readTVar v, addTestFetchedBlock = \p _b -> atomically (modifyTVar v (Set.insert p)) diff --git a/ouroboros-network/test/Ouroboros/Network/PeerSelection/Test.hs b/ouroboros-network/test/Ouroboros/Network/PeerSelection/Test.hs index 8c246d08c01..8738a6a6ba9 100644 --- a/ouroboros-network/test/Ouroboros/Network/PeerSelection/Test.hs +++ b/ouroboros-network/test/Ouroboros/Network/PeerSelection/Test.hs @@ -1,52 +1,53 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GeneralisedNewtypeDeriving #-} -{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -Wno-orphans #-} module Ouroboros.Network.PeerSelection.Test (tests) where -import Data.Void (Void) -import Data.Function (on) -import Data.Typeable (Typeable) -import Data.Dynamic (fromDynamic) -import Data.Maybe (listToMaybe) import qualified Data.ByteString.Char8 as BS -import Data.List (nub, groupBy) +import Data.Dynamic (fromDynamic) +import Data.Function (on) +import Data.Graph (Graph) +import qualified Data.Graph as Graph +import Data.List (groupBy, nub) +import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as NonEmpty -import Data.List.NonEmpty (NonEmpty(..)) -import qualified Data.Map.Strict as Map import Data.Map.Strict (Map) -import qualified Data.Set as Set +import qualified Data.Map.Strict as Map +import Data.Maybe (listToMaybe) import Data.Set (Set) -import qualified Data.Graph as Graph -import Data.Graph (Graph) +import qualified Data.Set as Set import qualified Data.Tree as Tree +import Data.Typeable (Typeable) +import Data.Void (Void) +import Control.Exception (throw) import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadSTM import Control.Monad.Class.MonadTime -import Control.Tracer (Tracer(..), contramap, traceWith) -import Control.Exception (throw) +import Control.Tracer (Tracer (..), contramap, traceWith) -import Control.Monad.IOSim import Control.Monad.Class.MonadTimer +import Control.Monad.IOSim -import Ouroboros.Network.PeerSelection.Types -import Ouroboros.Network.PeerSelection.Governor hiding (PeerSelectionState(..)) +import qualified Network.DNS as DNS (defaultResolvConf) +import Ouroboros.Network.PeerSelection.Governor hiding + (PeerSelectionState (..)) import qualified Ouroboros.Network.PeerSelection.Governor as Governor import qualified Ouroboros.Network.PeerSelection.KnownPeers as KnownPeers import Ouroboros.Network.PeerSelection.RootPeersDNS -import qualified Network.DNS as DNS (defaultResolvConf) +import Ouroboros.Network.PeerSelection.Types import Test.QuickCheck -import Test.Tasty (TestTree, testGroup, localOption) -import Test.Tasty.QuickCheck (testProperty, QuickCheckMaxSize(..)) +import Test.Tasty (TestTree, localOption, testGroup) +import Test.Tasty.QuickCheck (QuickCheckMaxSize (..), testProperty) tests :: TestTree @@ -344,11 +345,11 @@ mockPeerSelectionPolicy GovernorMockEnvironment { policyGossipOverallTimeout = 10 -- seconds } -interpretPickScript :: (MonadSTM m, Ord peeraddr) - => TVar m PickScript +interpretPickScript :: (MonadSTMTx stm, Ord peeraddr) + => TVar_ stm PickScript -> Map peeraddr a -> Int - -> STM m (Set peeraddr) + -> stm (Set peeraddr) interpretPickScript scriptVar available pickNum | Map.null available = error "interpretPickScript: given empty map to pick from" @@ -637,7 +638,7 @@ _notionallyReachablePeers :: PeerGraph -> Set PeerAddr -> Set PeerAddr _notionallyReachablePeers pg roots = Set.fromList . map vertexToAddr - . concatMap Tree.flatten + . concatMap Tree.flatten . Graph.dfs graph . map addrToVertex $ Set.toList roots @@ -648,7 +649,7 @@ firstGossipReachablePeers :: PeerGraph -> Set PeerAddr -> Set PeerAddr firstGossipReachablePeers pg roots = Set.fromList . map vertexToAddr - . concatMap Tree.flatten + . concatMap Tree.flatten . Graph.dfs graph . map addrToVertex $ Set.toList roots @@ -913,7 +914,7 @@ initScript = newTVarM stepScript :: MonadSTM m => TVar m (Script a) -> m a stepScript scriptVar = atomically (stepScriptSTM scriptVar) -stepScriptSTM :: MonadSTM m => TVar m (Script a) -> STM m a +stepScriptSTM :: MonadSTMTx stm => TVar_ stm (Script a) -> stm a stepScriptSTM scriptVar = do Script (x :| xs) <- readTVar scriptVar case xs of @@ -1080,4 +1081,3 @@ _governorFindingPublicRoots targetNumberOfRootPeers domains = } pickTrivially :: Applicative m => Map IPv4 a -> Int -> m (Set IPv4) pickTrivially m n = pure . Set.take n . Map.keysSet $ m - diff --git a/ouroboros-network/test/Test/Pipe.hs b/ouroboros-network/test/Test/Pipe.hs index b95192a1fab..88e3ee781e7 100644 --- a/ouroboros-network/test/Test/Pipe.hs +++ b/ouroboros-network/test/Test/Pipe.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -Wno-orphans #-} @@ -27,7 +28,7 @@ import Control.Tracer import qualified Network.Mux.Bearer.Pipe as Mx import Ouroboros.Network.Mux -import Ouroboros.Network.Block (encodeTip, decodeTip) +import Ouroboros.Network.Block (decodeTip, encodeTip) import Ouroboros.Network.MockChain.Chain (Chain, ChainUpdate, Point) import qualified Ouroboros.Network.MockChain.Chain as Chain import qualified Ouroboros.Network.MockChain.ProducerState as CPS diff --git a/ouroboros-network/test/Test/Socket.hs b/ouroboros-network/test/Test/Socket.hs index 65ef000628a..b8a963c3700 100644 --- a/ouroboros-network/test/Test/Socket.hs +++ b/ouroboros-network/test/Test/Socket.hs @@ -7,6 +7,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -Wno-orphans #-} module Test.Socket (tests) where @@ -43,13 +44,13 @@ import qualified Network.TypedProtocol.ReqResp.Codec.CBOR as ReqResp import Control.Tracer -- TODO: remove Mx prefixes -import qualified Network.Mux as Mx hiding (MiniProtocolLimits(..)) +import qualified Network.Mux as Mx hiding (MiniProtocolLimits (..)) import qualified Network.Mux.Bearer.Socket as Mx import Ouroboros.Network.Mux as Mx import Ouroboros.Network.Socket -import Ouroboros.Network.Block (Tip, encodeTip, decodeTip) +import Ouroboros.Network.Block (Tip, decodeTip, encodeTip) import Ouroboros.Network.Magic import Ouroboros.Network.MockChain.Chain (Chain, ChainUpdate, Point) import qualified Ouroboros.Network.MockChain.Chain as Chain @@ -59,7 +60,8 @@ import qualified Ouroboros.Network.Protocol.ChainSync.Client as ChainSync import qualified Ouroboros.Network.Protocol.ChainSync.Codec as ChainSync import qualified Ouroboros.Network.Protocol.ChainSync.Examples as ChainSync import qualified Ouroboros.Network.Protocol.ChainSync.Server as ChainSync -import Ouroboros.Network.Protocol.Handshake.Type (acceptEq, cborTermVersionDataCodec) +import Ouroboros.Network.Protocol.Handshake.Type (acceptEq, + cborTermVersionDataCodec) import Ouroboros.Network.Protocol.Handshake.Version (simpleSingletonVersions) import Ouroboros.Network.Testing.Serialise diff --git a/ouroboros-network/test/Test/Subscription.hs b/ouroboros-network/test/Test/Subscription.hs index 82a04680214..b9f9c049912 100644 --- a/ouroboros-network/test/Test/Subscription.hs +++ b/ouroboros-network/test/Test/Subscription.hs @@ -1,12 +1,12 @@ {-# LANGUAGE BangPatterns #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE NamedFieldPuns #-} - {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -Wno-orphans #-} @@ -54,16 +54,15 @@ import Ouroboros.Network.Protocol.Handshake.Version (simpleSingletonVe import Ouroboros.Network.Magic import Ouroboros.Network.Mux -import Ouroboros.Network.NodeToNode hiding ( ipSubscriptionWorker - , dnsSubscriptionWorker - ) +import Ouroboros.Network.NodeToNode hiding (dnsSubscriptionWorker, + ipSubscriptionWorker) import Ouroboros.Network.Socket import Ouroboros.Network.Subscription -import Ouroboros.Network.Subscription.Ip import Ouroboros.Network.Subscription.Dns -import Ouroboros.Network.Subscription.Worker (WorkerParams (..)) +import Ouroboros.Network.Subscription.Ip import Ouroboros.Network.Subscription.PeerState import Ouroboros.Network.Subscription.Subscriber +import Ouroboros.Network.Subscription.Worker (WorkerParams (..)) defaultMiniProtocolLimit :: Int64 defaultMiniProtocolLimit = 3000000 @@ -595,10 +594,10 @@ prop_send_recv f xs first = ioProperty $ do data ReqRspCfg = ReqRspCfg { - rrcTag :: !String - , rrcServerVar :: !(StrictTMVar IO Int) - , rrcClientVar :: !(StrictTMVar IO [Int]) - , rrcSiblingVar :: !(StrictTVar IO Int) + rrcTag :: !String + , rrcServerVar :: !(StrictTMVar IO Int) + , rrcClientVar :: !(StrictTMVar IO [Int]) + , rrcSiblingVar :: !(StrictTVar IO Int) } newReqRspCfg :: String -> StrictTVar IO Int -> IO ReqRspCfg @@ -846,5 +845,3 @@ instance (Show a) => Show (WithTag a) where tagTrace :: String -> Tracer IO (WithTag a) -> Tracer IO a tagTrace tag tr = Tracer $ \s -> traceWith tr $ WithTag tag s - -