Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for monad stacks #1539

Merged
merged 1 commit into from Feb 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
193 changes: 118 additions & 75 deletions io-sim-classes/src/Control/Monad/Class/MonadAsync.hs
@@ -1,12 +1,14 @@
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Control.Monad.Class.MonadAsync
( MonadAsync (..)
, MonadAsyncSTM (..)
, AsyncCancelled(..)
, ExceptionInLinkedThread(..)
, link
Expand All @@ -21,25 +23,81 @@ import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow

import Control.Monad (void)
import Control.Concurrent.Async (AsyncCancelled (..))
import qualified Control.Concurrent.Async as Async
import Control.Exception (SomeException)
import qualified Control.Exception as E
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.Async (AsyncCancelled(..))
import Control.Monad (void)
import Control.Monad.Reader
import qualified Control.Monad.STM as STM
import Data.Proxy

class (Functor async, MonadSTMTx stm) => MonadAsyncSTM async stm where
{-# MINIMAL waitCatchSTM, pollSTM #-}

waitSTM :: async a -> stm a
pollSTM :: async a -> stm (Maybe (Either SomeException a))
waitCatchSTM :: async a -> stm (Either SomeException a)

default waitSTM :: MonadThrow stm => async a -> stm a
waitSTM action = waitCatchSTM action >>= either throwM return

waitAnySTM :: [async a] -> stm (async a, a)
waitAnyCatchSTM :: [async a] -> stm (async a, Either SomeException a)
waitEitherSTM :: async a -> async b -> stm (Either a b)
waitEitherSTM_ :: async a -> async b -> stm ()
waitEitherCatchSTM :: async a -> async b
-> stm (Either (Either SomeException a)
(Either SomeException b))
waitBothSTM :: async a -> async b -> stm (a, b)

default waitAnySTM :: MonadThrow stm => [async a] -> stm (async a, a)
default waitEitherSTM :: MonadThrow stm => async a -> async b -> stm (Either a b)
default waitEitherSTM_ :: MonadThrow stm => async a -> async b -> stm ()
default waitBothSTM :: MonadThrow stm => async a -> async b -> stm (a, b)

waitAnySTM as =
foldr orElse retry $
map (\a -> do r <- waitSTM a; return (a, r)) as

waitAnyCatchSTM as =
foldr orElse retry $
map (\a -> do r <- waitCatchSTM a; return (a, r)) as

waitEitherSTM left right =
(Left <$> waitSTM left)
`orElse`
(Right <$> waitSTM right)

waitEitherSTM_ left right =
(void $ waitSTM left)
`orElse`
(void $ waitSTM right)

waitEitherCatchSTM left right =
(Left <$> waitCatchSTM left)
`orElse`
(Right <$> waitCatchSTM right)

waitBothSTM left right = do
a <- waitSTM left
`orElse`
(waitSTM right >> retry)
b <- waitSTM right
return (a,b)

class ( MonadSTM m
, MonadThread m
, Functor (Async m)
, MonadAsyncSTM (Async m) (STM m)
) => MonadAsync m where

{-# MINIMAL async, asyncThreadId, cancel, cancelWith, waitCatchSTM, pollSTM #-}
{-# MINIMAL async, asyncThreadId, cancel, cancelWith #-}

-- | An asynchronous action
type Async m :: * -> *

async :: m a -> m (Async m a)
asyncThreadId :: proxy m -> Async m a -> ThreadId m
asyncThreadId :: Proxy m -> Async m a -> ThreadId m
withAsync :: m a -> (Async m a -> m b) -> m b

wait :: Async m a -> m a
Expand All @@ -49,10 +107,6 @@ class ( MonadSTM m
cancelWith :: Exception e => Async m a -> e -> m ()
uninterruptibleCancel :: Async m a -> m ()

waitSTM :: Async m a -> STM m a
pollSTM :: Async m a -> STM m (Maybe (Either SomeException a))
waitCatchSTM :: Async m a -> STM m (Either SomeException a)

waitAny :: [Async m a] -> m (Async m a, a)
waitAnyCatch :: [Async m a] -> m (Async m a, Either SomeException a)
waitAnyCancel :: [Async m a] -> m (Async m a, a)
Expand All @@ -70,15 +124,6 @@ class ( MonadSTM m
waitEither_ :: Async m a -> Async m b -> m ()
waitBoth :: Async m a -> Async m b -> m (a, b)

waitAnySTM :: [Async m a] -> STM m (Async m a, a)
waitAnyCatchSTM :: [Async m a] -> STM m (Async m a, Either SomeException a)
waitEitherSTM :: Async m a -> Async m b -> STM m (Either a b)
waitEitherSTM_ :: Async m a -> Async m b -> STM m ()
waitEitherCatchSTM :: Async m a -> Async m b
-> STM m (Either (Either SomeException a)
(Either SomeException b))
waitBothSTM :: Async m a -> Async m b -> STM m (a, b)

race :: m a -> m b -> m (Either a b)
race_ :: m a -> m b -> m ()
concurrently :: m a -> m b -> m (a,b)
Expand All @@ -87,7 +132,6 @@ class ( MonadSTM m
default withAsync :: MonadMask m => m a -> (Async m a -> m b) -> m b
default uninterruptibleCancel
:: MonadMask m => Async m a -> m ()
default waitSTM :: MonadThrow (STM m) => Async m a -> STM m a
default waitAnyCancel :: MonadThrow m => [Async m a] -> m (Async m a, a)
default waitAnyCatchCancel :: MonadThrow m => [Async m a]
-> m (Async m a, Either SomeException a)
Expand All @@ -97,12 +141,6 @@ class ( MonadSTM m
-> m (Either (Either SomeException a)
(Either SomeException b))

default waitAnySTM :: MonadThrow (STM m) => [Async m a] -> STM m (Async m a, a)
default waitEitherSTM :: MonadThrow (STM m) => Async m a -> Async m b -> STM m (Either a b)
default waitEitherSTM_ :: MonadThrow (STM m) => Async m a -> Async m b -> STM m ()
default waitBothSTM :: MonadThrow (STM m) => Async m a -> Async m b -> STM m (a, b)


withAsync action inner = mask $ \restore -> do
a <- async (restore action)
restore (inner a)
Expand All @@ -113,7 +151,6 @@ class ( MonadSTM m
waitCatch = atomically . waitCatchSTM

uninterruptibleCancel = uninterruptibleMask_ . cancel
waitSTM action = waitCatchSTM action >>= either throwM return

waitAny = atomically . waitAnySTM
waitAnyCatch = atomically . waitAnyCatchSTM
Expand All @@ -134,36 +171,6 @@ class ( MonadSTM m
waitEitherCatchCancel left right =
waitEitherCatch left right `finally` (cancel left >> cancel right)

waitAnySTM as =
foldr orElse retry $
map (\a -> do r <- waitSTM a; return (a, r)) as

waitAnyCatchSTM as =
foldr orElse retry $
map (\a -> do r <- waitCatchSTM a; return (a, r)) as

waitEitherSTM left right =
(Left <$> waitSTM left)
`orElse`
(Right <$> waitSTM right)

waitEitherSTM_ left right =
(void $ waitSTM left)
`orElse`
(void $ waitSTM right)

waitEitherCatchSTM left right =
(Left <$> waitCatchSTM left)
`orElse`
(Right <$> waitCatchSTM right)

waitBothSTM left right = do
a <- waitSTM left
`orElse`
(waitSTM right >> retry)
b <- waitSTM right
return (a,b)

race left right = withAsync left $ \a ->
withAsync right $ \b ->
waitEither a b
Expand All @@ -180,6 +187,17 @@ class ( MonadSTM m
-- Instance for IO uses the existing async library implementations
--

instance MonadAsyncSTM Async.Async STM.STM where
waitSTM = Async.waitSTM
pollSTM = Async.pollSTM
waitCatchSTM = Async.waitCatchSTM
waitAnySTM = Async.waitAnySTM
waitAnyCatchSTM = Async.waitAnyCatchSTM
waitEitherSTM = Async.waitEitherSTM
waitEitherSTM_ = Async.waitEitherSTM_
waitEitherCatchSTM = Async.waitEitherCatchSTM
waitBothSTM = Async.waitBothSTM

instance MonadAsync IO where

type Async IO = Async.Async
Expand All @@ -195,10 +213,6 @@ instance MonadAsync IO where
cancelWith = Async.cancelWith
uninterruptibleCancel = Async.uninterruptibleCancel

waitSTM = Async.waitSTM
pollSTM = Async.pollSTM
waitCatchSTM = Async.waitCatchSTM

waitAny = Async.waitAny
waitAnyCatch = Async.waitAnyCatch
waitAnyCancel = Async.waitAnyCancel
Expand All @@ -210,17 +224,46 @@ instance MonadAsync IO where
waitEither_ = Async.waitEither_
waitBoth = Async.waitBoth

waitAnySTM = Async.waitAnySTM
waitAnyCatchSTM = Async.waitAnyCatchSTM
waitEitherSTM = Async.waitEitherSTM
waitEitherSTM_ = Async.waitEitherSTM_
waitEitherCatchSTM = Async.waitEitherCatchSTM
waitBothSTM = Async.waitBothSTM

race = Async.race
race_ = Async.race_
concurrently = Async.concurrently

--
-- Lift to ReaderT
--

(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d)
(f .: g) x y = f (g x y)

instance MonadAsync m => MonadAsync (ReaderT r m) where
type Async (ReaderT r m) = Async m

asyncThreadId _ = asyncThreadId (Proxy @m)

async (ReaderT ma) = ReaderT $ \r -> async (ma r)
withAsync (ReaderT ma) f = ReaderT $ \r -> withAsync (ma r) $ \a -> runReaderT (f a) r

race (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race (ma r) (mb r)
race_ (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race_ (ma r) (mb r)
concurrently (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> concurrently (ma r) (mb r)

wait = lift . wait
poll = lift . poll
waitCatch = lift . waitCatch
cancel = lift . cancel
uninterruptibleCancel = lift . uninterruptibleCancel
cancelWith = lift .: cancelWith
waitAny = lift . waitAny
waitAnyCatch = lift . waitAnyCatch
waitAnyCancel = lift . waitAnyCancel
waitAnyCatchCancel = lift . waitAnyCatchCancel
waitEither = lift .: waitEither
waitEitherCatch = lift .: waitEitherCatch
waitEitherCancel = lift .: waitEitherCancel
waitEitherCatchCancel = lift .: waitEitherCatchCancel
waitEither_ = lift .: waitEither_
waitBoth = lift .: waitBoth

--
-- Linking
--
Expand Down Expand Up @@ -275,7 +318,7 @@ linkToOnly tid shouldThrow a = do
r <- waitCatch a
case r of
Left e | shouldThrow e -> throwTo tid (exceptionInLinkedThread e)
_otherwise -> return ()
_otherwise -> return ()
where
exceptionInLinkedThread :: SomeException -> ExceptionInLinkedThread
exceptionInLinkedThread =
Expand Down
5 changes: 4 additions & 1 deletion io-sim-classes/src/Control/Monad/Class/MonadST.hs
@@ -1,7 +1,8 @@
{-# LANGUAGE RankNTypes #-}
module Control.Monad.Class.MonadST where

import Control.Monad.ST (ST, stToIO)
import Control.Monad.Reader
import Control.Monad.ST (ST, stToIO)


-- | This class is for abstracting over 'stToIO' which allows running 'ST'
Expand Down Expand Up @@ -29,3 +30,5 @@ instance MonadST IO where
instance MonadST (ST s) where
withLiftST = \f -> f id

instance MonadST m => MonadST (ReaderT r m) where
withLiftST f = withLiftST $ \g -> f (lift . g)