Skip to content

Commit

Permalink
refactoring, cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchellwrosen committed Nov 28, 2023
1 parent 59312dc commit f5e9a45
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 74 deletions.
28 changes: 16 additions & 12 deletions ki/src/Ki/Internal/IO.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ where
import Control.Exception
import Control.Monad (join)
import Data.Coerce (coerce)
import Data.Maybe (isJust)
import GHC.Base (maskAsyncExceptions#, maskUninterruptible#)
import GHC.Conc (STM, ThreadId (ThreadId), catchSTM)
import GHC.Exts (Int (I#), fork#, forkOn#)
import GHC.IO (IO (IO))
import Prelude

-- A little promise that this IO action cannot throw an exception.
-- A little promise that this IO action cannot throw an exception (*including* async exceptions, which you normally
-- think of as being able to strike at any time).
--
-- Yeah it's verbose, and maybe not that necessary, but the code that bothers to use it really does require
-- un-exceptiony IO actions for correctness, so here we are.
Expand All @@ -42,13 +44,17 @@ data IOResult a
= Failure !SomeException -- sync or async exception
| Success a

-- Try an action, catching any exception it throws.
--
-- The caller is responsible for ensuring that async exceptions are masked (at whatever masking level is appropriate),
-- as (again) `UnexceptionalIO` implies async exceptions won't be thrown either.
unexceptionalTry :: forall a. IO a -> UnexceptionalIO (IOResult a)
unexceptionalTry action =
UnexceptionalIO do
(Success <$> action) `catch` \exception ->
pure (Failure exception)

-- Like try, but with continuations. Also, catches all exceptions, because that's the only flavor we need.
-- Like try, but with continuations.
unexceptionalTryEither ::
forall a b.
(SomeException -> UnexceptionalIO b) ->
Expand All @@ -63,20 +69,18 @@ unexceptionalTryEither onFailure onSuccess action =
(pure . coerce @_ @(SomeException -> IO b) onFailure)

isAsyncException :: SomeException -> Bool
isAsyncException exception =
case fromException @SomeAsyncException exception of
Nothing -> False
Just _ -> True
isAsyncException =
isJust . fromException @SomeAsyncException

-- | Call an action with asynchronous exceptions interruptibly masked.
interruptiblyMasked :: IO a -> IO a
interruptiblyMasked (IO io) =
IO (maskAsyncExceptions# io)
interruptiblyMasked :: forall a. IO a -> IO a
interruptiblyMasked =
coerce (maskAsyncExceptions# @a)

-- | Call an action with asynchronous exceptions uninterruptibly masked.
uninterruptiblyMasked :: IO a -> IO a
uninterruptiblyMasked (IO io) =
IO (maskUninterruptible# io)
uninterruptiblyMasked :: forall a. IO a -> IO a
uninterruptiblyMasked =
coerce (maskUninterruptible# @a)

-- Like try, but with continuations
tryEitherSTM :: (Exception e) => (e -> STM b) -> (a -> STM b) -> STM a -> STM b
Expand Down
133 changes: 71 additions & 62 deletions ki/src/Ki/Internal/Scope.hs
Original file line number Diff line number Diff line change
Expand Up @@ -223,63 +223,74 @@ allocateScope = do

-- Spawn a thread in a scope, providing it its child id and a function that sets the masking state to the requested
-- masking state. The given action is called with async exceptions interruptibly masked.
spawn :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ThreadId
spawn
Scope {childrenVar, nextChildIdSupply, statusVar}
ThreadOptions {affinity, allocationLimit, label, maskingState = requestedChildMaskingState}
action = do
-- Interruptible mask is enough so long as none of the STM operations below block.
--
-- Unconditionally set masking state to MaskedInterruptible, even though we might already be at MaskedInterruptible
-- or MaskedUninterruptible, to avoid a branch on parentMaskingState.
interruptiblyMasked do
-- Record the thread as being about to start. Not allowed to retry.
nonblockingAtomically do
n <- nonblockingReadTVar statusVar
assert (n >= -2) do
case n of
Open -> nonblockingWriteTVar' statusVar (n + 1)
Closing -> nonblockingThrowSTM ScopeClosing
Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed")

childId <- IntSupply.next nextChildIdSupply

childThreadId <-
forkWithAffinity affinity do
when (not (null label)) do
childThreadId <- myThreadId
labelThread childThreadId label

for_ allocationLimit \bytes -> do
setAllocationCounter (byteCountToInt64 bytes)
enableAllocationLimit

let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one.
atRequestedMaskingState :: IO a -> IO a
atRequestedMaskingState =
case requestedChildMaskingState of
Unmasked -> unsafeUnmask
MaskedInterruptible -> id
MaskedUninterruptible -> uninterruptiblyMasked

runUnexceptionalIO (action childId atRequestedMaskingState)

nonblockingAtomically (unrecordChild childrenVar childId)

-- Record the child as having started. Not allowed to retry.
nonblockingAtomically do
n <- nonblockingReadTVar statusVar
nonblockingWriteTVar' statusVar (n - 1)
recordChild childrenVar childId childThreadId

pure childThreadId
spawn :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ChildIds
spawn scope@Scope {childrenVar, statusVar} options action = do
-- Interruptible mask is enough so long as none of the STM operations below block.
--
-- Unconditionally set masking state to MaskedInterruptible, even though we might already be at MaskedInterruptible
-- or MaskedUninterruptible, to avoid a branch on parentMaskingState.
interruptiblyMasked do
-- Record the thread as being about to start. Not allowed to retry.
nonblockingAtomically do
status <- nonblockingReadTVar statusVar
assert (status >= -2) do
case status of
Open -> nonblockingWriteTVar' statusVar (status + 1)
Closing -> nonblockingThrowSTM ScopeClosing
Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed")

childIds <- spawnChild scope options action

-- Record the child as having started. Not allowed to retry.
nonblockingAtomically do
starting <- nonblockingReadTVar statusVar
assert (starting >= 1) do
nonblockingWriteTVar' statusVar (starting - 1)
recordChild childrenVar childIds

pure childIds

data ChildIds
= ChildIds
{-# UNPACK #-} !Tid
{-# UNPACK #-} !ThreadId

spawnChild :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ChildIds
spawnChild scope options action = do
childId <- IntSupply.next nextChildIdSupply
childThreadId <-
forkWithAffinity affinity do
when (not (null label)) do
childThreadId <- myThreadId
labelThread childThreadId label

for_ allocationLimit \bytes -> do
setAllocationCounter (byteCountToInt64 bytes)
enableAllocationLimit

let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one.
atRequestedMaskingState :: IO a -> IO a
atRequestedMaskingState =
case requestedChildMaskingState of
Unmasked -> unsafeUnmask
MaskedInterruptible -> id
MaskedUninterruptible -> uninterruptiblyMasked

runUnexceptionalIO (action childId atRequestedMaskingState)

nonblockingAtomically (unrecordChild childrenVar childId)
pure (ChildIds childId childThreadId)
where
Scope {childrenVar, nextChildIdSupply} = scope
ThreadOptions {affinity, allocationLimit, label, maskingState = requestedChildMaskingState} = options
{-# INLINE spawnChild #-}

-- Record our child by either:
--
-- * Flipping `Nothing` to `Just childThreadId` (common case: we record child before it unrecords itself)
-- * Flipping `Just _` to `Nothing` (uncommon case: we observe that a child already unrecorded itself)
recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> NonblockingSTM ()
recordChild childrenVar childId childThreadId = do
recordChild :: TVar (IntMap ThreadId) -> ChildIds -> NonblockingSTM ()
recordChild childrenVar (ChildIds childId childThreadId) = do
children <- nonblockingReadTVar childrenVar
nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children)

Expand All @@ -298,7 +309,7 @@ awaitAll Scope {childrenVar, statusVar} = do
children <- readTVar childrenVar
guard (IntMap.Lazy.null children)
status <- readTVar statusVar
case status of
assert (status >= -2) case status of
Open -> guard (status == 0)
Closing -> retry -- block until closed
Closed -> pure ()
Expand All @@ -321,14 +332,12 @@ forkWith :: Scope -> ThreadOptions -> IO a -> IO (Thread a)
forkWith scope opts action = do
resultVar <- newTVarIO NoResultYet
let done result = UnexceptionalIO (atomically (writeTVar resultVar result))
ident <-
ChildIds _ childThreadId <-
spawn scope opts \childId masking -> do
result <- unexceptionalTry (masking action)
case result of
unexceptionalTry (masking action) >>= \case
Failure exception -> do
when
(not (isScopeClosingException exception))
(propagateException scope childId exception)
when (not (isScopeClosingException exception)) do
propagateException scope childId exception
-- even put async exceptions that we propagated. this isn't totally ideal because a caller awaiting this
-- thread would not be able to distinguish between async exceptions delivered to this thread, or itself
done (BadResult exception)
Expand All @@ -338,7 +347,7 @@ forkWith scope opts action = do
NoResultYet -> retry
BadResult exception -> throwSTM exception
GoodResult value -> pure value
pure (makeThread ident doAwait)
pure (makeThread childThreadId doAwait)

-- | Variant of 'Ki.forkWith' for threads that never return.
forkWith_ :: Scope -> ThreadOptions -> IO Void -> IO ()
Expand Down Expand Up @@ -369,7 +378,7 @@ forkTryWith :: forall e a. (Exception e) => Scope -> ThreadOptions -> IO a -> IO
forkTryWith scope opts action = do
resultVar <- newTVarIO NoResultYet
let done result = UnexceptionalIO (atomically (writeTVar resultVar result))
childThreadId <-
ChildIds _ childThreadId <-
spawn scope opts \childId masking -> do
result <- unexceptionalTry (masking action)
case result of
Expand Down Expand Up @@ -427,7 +436,7 @@ forkTryWith scope opts action = do
propagateException :: Scope -> Tid -> SomeException -> UnexceptionalIO ()
propagateException Scope {childExceptionVar, parentThreadId, statusVar} childId exception =
UnexceptionalIO (readTVarIO statusVar) >>= \case
Closing -> tryPutChildExceptionVar -- (A) / (B)
Closing -> tryPutChildExceptionVar -- (A) or (B), we don't care which
status -> assert (status >= 0) loop -- we know status is Open here
where
loop :: UnexceptionalIO ()
Expand Down

0 comments on commit f5e9a45

Please sign in to comment.