Skip to content

Commit

Permalink
Sharing recovery fix
Browse files Browse the repository at this point in the history
The NodeCounts data structure has changed substantially. It is no longer a list of
node/count pairs with the invariant that all child nodes appear after parent nodes.  It is
now a list of "dependency groups"; each dependency group is a directed acyclic graph (DAG)
that captures the dependencies that an AccSharing node can have on VarSharing nodes.

The function 'filterCompleted' now checks and filters out nodes on a per-dependency group
basis.

This commit also introduces a new trace function in D.A.A.Debug and adds tests for sharing
recovery.

There are also some new dependencies added to accelerate.cabal because we are using
unordered containers.
  • Loading branch information
Sean Seefried committed May 13, 2011
1 parent c16f19d commit 5c24257
Show file tree
Hide file tree
Showing 6 changed files with 380 additions and 38 deletions.
11 changes: 10 additions & 1 deletion Data/Array/Accelerate/Debug.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
module Data.Array.Accelerate.Debug (

-- * Conditional tracing
initTrace, queryTrace, traceLine, traceChunk
initTrace, queryTrace, traceLine, traceChunk, trace

) where

Expand All @@ -24,6 +24,7 @@ import Control.Monad
import Data.IORef
import System.IO
import System.IO.Unsafe (unsafePerformIO)
import qualified Debug.Trace as Trace

-- friends
import Data.Array.Accelerate.Pretty ()
Expand Down Expand Up @@ -66,3 +67,11 @@ traceChunk header msg
; when doTrace
$ hPutStrLn stderr (header ++ "\n " ++ msg)
}

-- | Conditional variant of 'Debug.Trace.trace': the message is emitted only when the
-- /trace flag/ (as reported by 'queryTrace') is set; otherwise the value is returned
-- without producing any output.
--
-- The flag is queried inside the 'unsafePerformIO' action on every use, so the decision
-- is taken at the time the traced value is demanded.
trace :: String -> a -> a
trace msg v = unsafePerformIO $ do
  enabled <- queryTrace
  return $ if enabled then Trace.trace msg v else v
219 changes: 188 additions & 31 deletions Data/Array/Accelerate/Smart.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ import Data.Typeable
import System.Mem.StableName
import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding (exp)
import Text.Printf
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as Map
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set

-- friends
import Data.Array.Accelerate.Debug
Expand All @@ -92,7 +98,6 @@ import Data.Array.Accelerate.Pretty ()
floatOutAccFromExp :: Bool
floatOutAccFromExp = True


-- Layouts
-- -------

Expand Down Expand Up @@ -478,6 +483,9 @@ instance Eq StableAccName where
| Just sn1' <- gcast sn1 = sn1' == sn2
| otherwise = False

-- Hash a 'StableAccName' by delegating to 'hashStableName' on the wrapped stable name,
-- which allows 'StableAccName' to be used as a key in 'HashMap' and 'HashSet'.
instance Hashable StableAccName where
hash (StableAccName sn) = hashStableName sn

makeStableAcc :: Acc arrs -> IO (StableName (Acc arrs))
makeStableAcc acc = acc `seq` makeStableName acc

Expand Down Expand Up @@ -844,16 +852,83 @@ makeOccMap rootAcc
travTup NilTup = return NilTup
travTup (SnocTup tup e) = pure SnocTup <*> travTup tup <*> traverseExp updateMap enter e

-- Type used to maintain how often each shared subterm occured.
--
-- Invariant: If one shared term 's' is itself a subterm of another shared term 't', then 's'
-- must occur *after* 't' in the 'NodeCounts'. Moreover, no shared term occur twice.
-- 'NodeCounts' is a type used to maintain how often each shared subterm occurred and the
-- dependencies between shared subterms.
--
-- During phase 2 of the algorithm, when a shared subterm is replaced with a 'VarSharing'
-- node it may already contain another 'VarSharing' node. In this case it is said to have
-- a dependency on it. The 'DepGroup' data type captures these dependencies.
--
-- Invariant (for DepGroup): If one shared term 's' is itself a subterm of another shared
-- term 't', then edge '(t,s)' must appear in the DepGroup.
--
-- Invariant (for NodeCounts): The 'DepGroup's do not overlap. That is, for any two
-- 'DepGroup's, 'dg1' and 'dg2' the intersection of the nodes of dg1 and dg2 is empty.
-- The order of DepGroups themselves does not matter.
--
--
-- To ensure the invariant is preserved over merging node counts from sibling subterms,
-- the function '(+++)' must be used.
--

-- We use the 'HashMap' and 'HashSet' data structure of the 'unordered-containers' package
-- because 'StableAccName' does not have an 'Ord' instance.
-- A dependency group: a graph over shared subterms, keyed by 'StableAccName'.
--
--  * 'nodes' maps each member to its 'StableSharingAcc' together with the sharing count
--    accumulated for it so far.
--  * 'edges' maps a source node to the set of nodes it has an edge to.
data DepGroup = DepGroup { nodes :: HashMap StableAccName (StableSharingAcc, Int)
, edges :: HashMap StableAccName (HashSet StableAccName) }

-- Human-readable rendering for debug traces: the node payloads (sharing node and count)
-- followed by the adjacency entries as comma-separated @(source, [targets])@ pairs.
instance Show DepGroup where
  show dg = "DepGroup { nodes = " ++ show nodePayloads
         ++ ", edges = [" ++ renderEdges (Map.toList (edges dg)) ++ "] }"
    where
      nodePayloads = map snd (Map.toList (nodes dg))

      -- Comma-separate the rendered adjacency entries (no trailing separator).
      renderEdges :: [(StableAccName, HashSet StableAccName)] -> String
      renderEdges []         = ""
      renderEdges [e]        = renderEdge e
      renderEdges (e : rest) = renderEdge e ++ ", " ++ renderEdges rest

      renderEdge :: (StableAccName, HashSet StableAccName) -> String
      renderEdge (source, targets) =
        "(" ++ show source ++ ", " ++ show (Set.toList targets) ++ ")"

newtype NodeCounts = NodeCounts [DepGroup] deriving Show

-- The dependency group with no nodes and no edges.
emptyDepGroup :: DepGroup
emptyDepGroup = DepGroup Map.empty Map.empty

--
-- To ensure the invariant is preserved over merging node counts from sibling subterms, the
-- function '(+++)' must be used.
-- Merges two dependency groups.
--
newtype NodeCounts = NodeCounts [(StableSharingAcc, Int)]
deriving Show
-- Union of two dependency groups. A node present in both groups gets the sum of the two
-- counts and a representative chosen with 'pickNoneVar'; the edge sets of a source node
-- present in both groups are unioned.
mergeDepGroup :: DepGroup -> DepGroup -> DepGroup
mergeDepGroup dg1 dg2 =
  DepGroup { nodes = absorb combineNode (nodes dg1) (nodes dg2)
           , edges = absorb Set.union   (edges dg1) (edges dg2) }
  where
    -- Insert every binding of 'extra' into 'base'. On a key clash 'combine' is applied
    -- to the value from 'extra' first and the value already in 'base' second.
    absorb :: (Eq k, Hashable k)
           => (v -> v -> v) -> HashMap k v -> HashMap k v -> HashMap k v
    absorb combine base extra =
      Map.foldlWithKey' (\acc key val -> Map.insertWith combine key val acc) base extra

    combineNode (saNew, countNew) (saOld, countOld) =
      (saNew `pickNoneVar` saOld, countNew + countOld)


-- Register one occurrence of 'sa' in the group: a node not yet present is entered with
-- count 1; for a node already present the count is incremented and the representative
-- is chosen with 'pickNoneVar' (the newly supplied node is the first argument).
depGroupInsertNode :: StableSharingAcc -> DepGroup -> DepGroup
depGroupInsertNode sa dg =
  dg { nodes = Map.insertWith bump (stableAccNameOf sa) (sa, 1) (nodes dg) }
  where
    bump (fresh, _) (existing, count) = (fresh `pickNoneVar` existing, 1 + count)

-- Record an edge from the node 'src' to the node named by the second argument, adding
-- the target to any targets already recorded for 'src'.
--
-- Precondition: The node must already be a member
depGroupInsertEdge :: StableSharingAcc -> StableAccName -> DepGroup -> DepGroup
depGroupInsertEdge src tgtSA dg =
  let sourceName = stableAccNameOf src
      grownEdges = Map.insertWith Set.union sourceName (Set.singleton tgtSA) (edges dg)
  in  dg { edges = grownEdges }

-- Choose between two sharing nodes: the second one if the first merely wraps a
-- 'VarSharing', the first one otherwise.
pickNoneVar :: StableSharingAcc -> StableSharingAcc -> StableSharingAcc
pickNoneVar sa1 sa2 =
  case sa1 of
    StableSharingAcc _ (VarSharing _) -> sa2
    _                                 -> sa1

-- Project out the stable name of a sharing node, wrapped as a 'StableAccName'.
stableAccNameOf :: StableSharingAcc -> StableAccName
stableAccNameOf ssa =
  case ssa of
    StableSharingAcc sn _ -> StableAccName sn

-- Empty node counts
--
Expand All @@ -862,8 +937,24 @@ noNodeCounts = NodeCounts []

-- Singleton node counts
--
nodeCount :: (StableSharingAcc, Int) -> NodeCounts
nodeCount nc = NodeCounts [nc]
-- Merges all the 'DepGroup's in 'subCounts', adds the node 'stableSharingAcc' with a
-- sharing count of 1, and also adds edges from this node to all the nodes in the merged
-- dependency group.
--
-- Build the node counts for one occurrence of 'stableSharingAcc' sitting above the
-- sub-counts of its subterms. The result always consists of exactly one 'DepGroup'.
nodeCount :: StableSharingAcc -> NodeCounts -> NodeCounts
nodeCount stableSharingAcc (NodeCounts subCounts) = NodeCounts [depGroup]
  where
    -- All dependency groups of the subterms collapsed into a single group. The
    -- left-to-right merge order is kept as is, because 'mergeDepGroup' is sensitive to
    -- argument order when it picks a representative for a node present on both sides.
    mergedDepGroup :: DepGroup
    mergedDepGroup = foldl mergeDepGroup emptyDepGroup subCounts

    -- Insert the node itself, then add an edge from it to each node of
    -- 'mergedDepGroup' (the key set is taken before the node is inserted). This is
    -- probably overkill. It would only be necessary to add an edge to the "root" node
    -- of each 'DepGroup' in 'subCounts'.
    depGroup :: DepGroup
    depGroup = Set.foldr (depGroupInsertEdge stableSharingAcc)
                         (depGroupInsertNode stableSharingAcc mergedDepGroup)
                         (keysSet (nodes mergedDepGroup))

-- Combine node counts that belong to the same node.
--
Expand All @@ -874,19 +965,62 @@ nodeCount nc = NodeCounts [nc]
-- nesting depth, but doesn't seem worthwhile as the arguments are expected to be fairly short.
-- Change if profiling suggests that this function is a bottleneck.
--

(+++) :: NodeCounts -> NodeCounts -> NodeCounts
NodeCounts us +++ NodeCounts vs = NodeCounts $ merge us vs
-- Merge the dependency groups of both operands. When both sides are non-empty the
-- operands and the merged result are reported through 'trace'. Emptiness is tested with
-- 'null', which inspects only the first cons cell instead of walking the whole list.
NodeCounts us +++ NodeCounts vs = NodeCounts $
  let result = merge us vs
  in  if not (null us) && not (null vs)
        then trace (printf " %s\n `merge`\n %s\n ==\n %s"
                           (show us) (show vs) (show result)) result
        else result
where
merge :: [DepGroup] -> [DepGroup] -> [DepGroup]
merge [] ys = ys
merge xs [] = xs
merge xs@(x@(sa1, count1) : xs') ys@(y@(sa2, count2) : ys')
| sa1 == sa2 = (sa1 `pickNoneVar` sa2, count1 + count2) : merge xs' ys'
| sa1 `notElem` map fst ys' = x : merge xs' ys
| sa2 `notElem` map fst xs' = y : merge xs ys'
| otherwise = INTERNAL_ERROR(error) "(+++)" "Precondition violated"
merge (xs:xss) yss = mergeInto xs (merge xss yss)

-- Fold a single dependency group into a list of groups: every group of the list that
-- overlaps with it is absorbed into it (via 'mergeDepGroup'), non-overlapping groups are
-- passed through in order, and the resulting group is placed at the end of the list.
mergeInto :: DepGroup -> [DepGroup] -> [DepGroup]
mergeInto dg groups =
  case groups of
    []                  -> [dg]
    g : rest
      | overlap dg g    -> mergeInto (mergeDepGroup dg g) rest
      | otherwise       -> g : mergeInto dg rest

-- Two dependency groups overlap when they have at least one node in common.
--
-- Each key of 'dg1' is probed in the node map of 'dg2', so the check costs one hash
-- lookup per node of 'dg1' instead of comparing the two key lists pairwise. 'any' stops
-- at the first common node.
overlap :: DepGroup -> DepGroup -> Bool
overlap dg1 dg2 = any presentInDg2 (Map.keys (nodes dg1))
  where
    presentInDg2 key = case Map.lookup key (nodes dg2) of
                         Just _  -> True
                         Nothing -> False

(StableSharingAcc _ (VarSharing _)) `pickNoneVar` sa2 = sa2
sa1 `pickNoneVar` _sa2 = sa1
-- The set of all keys of a hash map.
keysSet :: (Eq k, Hashable k) => HashMap k a -> HashSet k
keysSet hashMap = Set.fromList (Map.keys hashMap)

type TopoSortState = (HashSet StableAccName, [(StableSharingAcc, Int)])

--
-- Returns the 'StableSharingAcc's for a 'DepGroup' along with the sharing count
-- in reverse-binding order.
--
-- The order is produced by a depth-first topological sort over the group's edges: a node
-- is consed onto the result only after all of its successors have been processed, so
-- (provided the edges form a DAG) every node precedes the nodes it has edges to.
--
stableSharingAccsForDepGroup :: DepGroup -> [(StableSharingAcc, Int)]
stableSharingAccsForDepGroup dg = topoSort (Map.keys $ nodes dg) (Set.empty, [])
  where
    -- Visit every node of the group in turn, threading the visited set and the result
    -- accumulated so far. Threading the accumulator yields the same list as appending
    -- the per-root results back to front, without the repeated list concatenation.
    topoSort :: [StableAccName] -> (HashSet StableAccName, [(StableSharingAcc, Int)])
             -> [(StableSharingAcc, Int)]
    topoSort []         (_, accum) = accum
    topoSort (san:sans) state      = topoSort sans (visit san state)

    -- Depth-first visit of one node; a node that was already visited leaves the state
    -- untouched.
    visit :: StableAccName -> (HashSet StableAccName, [(StableSharingAcc, Int)])
          -> (HashSet StableAccName, [(StableSharingAcc, Int)])
    visit san this@(visited, accum)
      | Set.member san visited = this
      | otherwise =
          let visited' = Set.insert san visited
              -- Every visited name is expected to be a key of 'nodes'; an edge whose
              -- target was never inserted as a node violates that expectation.
              element  = case Map.lookup san (nodes dg) of
                           Just entry -> entry
                           Nothing    -> error
                             "stableSharingAccsForDepGroup: edge refers to a node that is not in the DepGroup"
              (visited'', accum'') = case Map.lookup san (edges dg) of
                                       Just succs -> Set.foldr visit (visited', accum) succs
                                       Nothing    -> (visited', accum)
          in  (visited'', element : accum'')

-- Determine the scopes of all variables representing shared subterms (Phase Two) in a bottom-up
-- sweep. The first argument determines whether array computations are floated out of expressions
Expand All @@ -901,7 +1035,10 @@ determineScopes floatOutAcc occMap rootAcc = fst $ scopesAcc rootAcc
scopesAcc (LetSharing _ _)
= INTERNAL_ERROR(error) "determineScopes: scopes" "unexpected 'LetSharing'"
scopesAcc sharingAcc@(VarSharing sn)
= (VarSharing sn, nodeCount (StableSharingAcc sn sharingAcc, 1))
= trace debugMsg (VarSharing sn, newCount)
where
newCount = nodeCount (StableSharingAcc sn sharingAcc) noNodeCounts
debugMsg = printf "%s: (VarSharing) %s" (show $ StableAccName sn) (show newCount)
scopesAcc (AccSharing sn pacc)
= case pacc of
Atag i -> reconstruct (Atag i) noNodeCounts
Expand Down Expand Up @@ -1055,36 +1192,53 @@ determineScopes floatOutAcc occMap rootAcc = fst $ scopesAcc rootAcc
--
reconstruct :: Arrays arrs
=> PreAcc SharingAcc arrs -> NodeCounts -> (SharingAcc arrs, NodeCounts)
reconstruct newAcc subCount
| occCount > 1 = ( VarSharing sn
, nodeCount (StableSharingAcc sn sharingAcc, 1) +++ newCount)
| otherwise = (sharingAcc, newCount)
reconstruct newAcc subCount = trace debugMsg reconstruct'
where
debugMsg = printf ("%s: bindHere = %s\n subCount = %s" ++
"\n newCount = %s" ++
"\n newNodeCounts = %s")
(show (StableAccName sn))
(show $ bindHere)
(show $ subCount)
(show $ newCount)
(if occCount > 1 then show
(nodeCount (StableSharingAcc sn sharingAcc) newCount) else "---")
reconstruct'
| occCount > 1 = (VarSharing sn, nodeCount (StableSharingAcc sn sharingAcc) newCount)
| otherwise = (sharingAcc, newCount)
-- Determine the bindings that need to be attached to the current node...
(newCount, bindHere) = filterCompleted subCount

-- ...and wrap them in 'LetSharing' constructors
lets = foldl (flip (.)) id . map LetSharing $ bindHere
sharingAcc = lets $ AccSharing sn newAcc

-- Extract nodes that have a complete node count (i.e., their node count is equal to the
-- number of occurences of that node in the overall expression) => nodes with a completed
-- node count should be let bound at the currently processed node.
-- Extract nodes that have a complete node count (i.e., their node count is equal
-- to the number of occurrences of that node in the overall expression) => nodes
-- with a completed node count should be let bound at the currently processed
-- node.
--
-- Nodes are extracted on a per dependency group ('DepGroup') basis. If all the
-- nodes in a dependency group have a sharing count equal to their occurrence
-- counts then they are filtered out.
--
filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingAcc])
filterCompleted (NodeCounts counts)
= let (counts', completed) = fc counts
in (NodeCounts counts', completed)
where
fc [] = ([], [])
fc (sub@(sa, n):subs)
fc (sub:subs)
-- current node is the binding point for the shared node 'sa'
| occCount == n = (subs', sa:bindHere)
| readyToBind sub =
let sas = map fst $ stableSharingAccsForDepGroup sub
in (subs', sas ++ bindHere)
-- not a binding point
| otherwise = (sub:subs', bindHere)
where
occCount = lookupWithSharingAcc occMap sa
(subs', bindHere) = fc subs
-- A dependency group is ready to be bound once the sharing count of every one of its
-- nodes equals that node's occurrence count in 'occMap'. The check does not depend on
-- any ordering of the nodes, so it walks the node map directly rather than
-- topologically sorting the group first.
readyToBind :: DepGroup -> Bool
readyToBind dg = all complete (map snd (Map.toList (nodes dg)))
  where
    complete (sa, n) = lookupWithSharingAcc occMap sa == n

scopesExp :: forall arrs. SharingExp arrs -> (SharingExp arrs, NodeCounts)
scopesExp pacc
Expand Down Expand Up @@ -1151,10 +1305,13 @@ determineScopes floatOutAcc occMap rootAcc = fst $ scopesAcc rootAcc
-> (SharingExp b, NodeCounts)
maybeFloatOutAcc c acc@(VarSharing _) accCount = (c acc, accCount) -- nothing to float out
maybeFloatOutAcc c acc accCount
| floatOutAcc = (c var, nodeCount (stableAcc, 1) +++ accCount)
| floatOutAcc = trace debugMsg (c var, floatedCount)
| otherwise = (c acc, accCount)
where
(var, stableAcc) = abstract acc id
floatedCount = nodeCount stableAcc accCount
debugMsg = printf "Floating out %s:\n nodeCount = %s"
(show stableAcc) (show $ floatedCount)

abstract :: SharingAcc a -> (SharingAcc a -> SharingAcc a)
-> (SharingAcc a, StableSharingAcc)
Expand Down
1 change: 0 additions & 1 deletion accelerate-examples/src/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,3 @@ main = do
valid <- runVerify config tests
--
unless (null valid || cfgVerify config) $ runTiming config valid

0 comments on commit 5c24257

Please sign in to comment.