From 5c2425791a4619f3e9bc0b7c47119183d029906f Mon Sep 17 00:00:00 2001 From: Sean Seefried Date: Fri, 13 May 2011 16:39:50 +1000 Subject: [PATCH] Sharing recovery fix The NodeCounts data structure has changed substantially. It is no longer a list of node/count pairs with the invariant that all child nodes appear after parent nodes. It is now a list of "dependency groups" each dependency group is a directed acyclic graph (DAG) that captures the dependencies that an AccSharing node can have on VarSharing nodes. The function 'filterCompleted' now checks and filters out nodes on a per-dependency group basis. This commit also introduces a new trace function in D.A.A.Debug and adds tests for sharing recovery. There are also some new dependencies added to accelerate.cabal because we are using unordered containers. --- Data/Array/Accelerate/Debug.hs | 11 +- Data/Array/Accelerate/Smart.hs | 219 +++++++++++++++--- accelerate-examples/src/Main.hs | 1 - accelerate-examples/src/Test.hs | 18 +- .../tests/simple/SharingRecovery.hs | 165 +++++++++++++ accelerate.cabal | 4 +- 6 files changed, 380 insertions(+), 38 deletions(-) create mode 100644 accelerate-examples/tests/simple/SharingRecovery.hs diff --git a/Data/Array/Accelerate/Debug.hs b/Data/Array/Accelerate/Debug.hs index d8a44acb7..d55611e8e 100644 --- a/Data/Array/Accelerate/Debug.hs +++ b/Data/Array/Accelerate/Debug.hs @@ -15,7 +15,7 @@ module Data.Array.Accelerate.Debug ( -- * Conditional tracing - initTrace, queryTrace, traceLine, traceChunk + initTrace, queryTrace, traceLine, traceChunk, trace ) where @@ -24,6 +24,7 @@ import Control.Monad import Data.IORef import System.IO import System.IO.Unsafe (unsafePerformIO) +import qualified Debug.Trace as Trace -- friends import Data.Array.Accelerate.Pretty () @@ -66,3 +67,11 @@ traceChunk header msg ; when doTrace $ hPutStrLn stderr (header ++ "\n " ++ msg) } + +-- | Like Debug.Trace but only when the /trace flag/ is set. +trace :: String -> a -> a +trace msg v = unsafePerformIO $ do + doTrace <- queryTrace + if doTrace + then return $ Trace.trace msg v + else return v \ No newline at end of file diff --git a/Data/Array/Accelerate/Smart.hs b/Data/Array/Accelerate/Smart.hs index 783b407f2..1fb621b19 100644 --- a/Data/Array/Accelerate/Smart.hs +++ b/Data/Array/Accelerate/Smart.hs @@ -68,6 +68,12 @@ import Data.Typeable import System.Mem.StableName import System.IO.Unsafe (unsafePerformIO) import Prelude hiding (exp) +import Text.Printf +import Data.Hashable +import Data.HashMap.Strict (HashMap) +import qualified Data.HashMap.Strict as Map +import Data.HashSet (HashSet) +import qualified Data.HashSet as Set -- friends import Data.Array.Accelerate.Debug @@ -92,7 +98,6 @@ import Data.Array.Accelerate.Pretty () floatOutAccFromExp :: Bool floatOutAccFromExp = True - -- Layouts -- ------- @@ -478,6 +483,9 @@ instance Eq StableAccName where | Just sn1' <- gcast sn1 = sn1' == sn2 | otherwise = False +instance Hashable StableAccName where + hash (StableAccName sn) = hashStableName sn + makeStableAcc :: Acc arrs -> IO (StableName (Acc arrs)) makeStableAcc acc = acc `seq` makeStableName acc @@ -844,16 +852,83 @@ makeOccMap rootAcc travTup NilTup = return NilTup travTup (SnocTup tup e) = pure SnocTup <*> travTup tup <*> traverseExp updateMap enter e --- Type used to maintain how often each shared subterm occured. -- --- Invariant: If one shared term 's' is itself a subterm of another shared term 't', then 's' --- must occur *after* 't' in the 'NodeCounts'. Moreover, no shared term occur twice. +-- 'NodeCounts' is a type used to maintain how often each shared subterm occured and the +-- dependencies between shared subterms. +-- +-- During phase 2 of the algorithm, when a shared subterm is replaced with a 'VarSharing' +-- node it may already contain another 'VarSharing' node. In this case it is said to have +-- a dependency on it. The 'DepGroup' data type captures these dependencies. +-- +-- Invariant (for DepGroup): If one shared term 's' is itself a subterm of another shared +-- term 't', then edge '(t,s)' must appear in the DepGroup. +-- +-- Invariant (for NodeCounts): The 'DepGroup's do not overlap. That is, for any two +-- 'DepGroup's, 'dg1' and 'dg2' the intersection of the nodes of dg1 and dg2 is empty. +-- The order of DepGroups themselves does not matter. +-- +-- +-- To ensure the invariant is preserved over merging node counts from sibling subterms, +-- the function '(+++)' must be used. +-- + +-- We use the 'HashMap' and 'HashSet' data structure of the 'unordered-containers' package +-- because 'StableAccName' does not have an 'Ord' instance. +data DepGroup = DepGroup { nodes :: HashMap StableAccName (StableSharingAcc, Int) + , edges :: HashMap StableAccName (HashSet StableAccName) } + +-- Nicer output for debugging. +instance Show DepGroup where + show dg = printf "DepGroup { nodes = %s, edges = [%s] }" + (show (map snd $ Map.toList (nodes dg))) + (showNodes (Map.toList $ edges dg)) + where + showNodes [] = "" + showNodes [s] = showNodes1 s + showNodes (s:ss) = printf "%s, %s" (showNodes1 s) (showNodes ss) + showNodes1 :: (StableAccName, HashSet StableAccName) -> String + showNodes1 (sa,set) = printf "(%s, %s)" (show sa) (show (Set.toList set)) + +newtype NodeCounts = NodeCounts [DepGroup] deriving Show + +emptyDepGroup :: DepGroup +emptyDepGroup = DepGroup { nodes = Map.empty, edges = Map.empty } + -- --- To ensure the invariant is preserved over merging node counts from sibling subterms, the --- function '(+++)' must be used. +-- Merges two dependency groups. -- -newtype NodeCounts = NodeCounts [(StableSharingAcc, Int)] - deriving Show +mergeDepGroup :: DepGroup -> DepGroup -> DepGroup +mergeDepGroup dg1 dg2 = DepGroup newNodes newEdges + where + newNodes = Map.foldlWithKey' (\m k v -> Map.insertWith updateCount k v m) + (nodes dg1) (nodes dg2) + updateCount (sa1, count1) (sa2, count2) = (sa1 `pickNoneVar` sa2, count1 + count2) + newEdges = Map.foldlWithKey' (\m k v -> Map.insertWith Set.union k v m) + (edges dg1) (edges dg2) + + +depGroupInsertNode :: StableSharingAcc -> DepGroup -> DepGroup +depGroupInsertNode sa dg = dg { nodes = newNodes } + where + san = stableAccNameOf sa + newNode = case Map.lookup san (nodes dg) of + Just (sa', count) -> (sa `pickNoneVar` sa', 1 + count) + Nothing -> (sa, 1) + newNodes = Map.insert san newNode (nodes dg) + +-- Precondition: The node must already be a member +depGroupInsertEdge :: StableSharingAcc -> StableAccName -> DepGroup -> DepGroup +depGroupInsertEdge src tgtSA dg = dg { edges = newEdges } + where + srcSA = stableAccNameOf src + newEdges = Map.insertWith Set.union srcSA (Set.singleton tgtSA) (edges dg) + +pickNoneVar :: StableSharingAcc -> StableSharingAcc -> StableSharingAcc +(StableSharingAcc _ (VarSharing _)) `pickNoneVar` sa2 = sa2 +sa1 `pickNoneVar` _sa2 = sa1 + +stableAccNameOf :: StableSharingAcc -> StableAccName +stableAccNameOf (StableSharingAcc sn _) = StableAccName sn -- Empty node counts -- @@ -862,8 +937,24 @@ noNodeCounts = NodeCounts [] -- Singleton node counts -- -nodeCount :: (StableSharingAcc, Int) -> NodeCounts -nodeCount nc = NodeCounts [nc] +-- Merges all the 'DepGroup's in 'subCounts', add the node 'stableSharingAcc' with a +-- sharing count of 1, and also adds edges from this node to all the nodes in the merged +-- dependency group. +-- +nodeCount :: StableSharingAcc -> NodeCounts -> NodeCounts +nodeCount stableSharingAcc (NodeCounts subCounts) = + NodeCounts $ [depGroup] + where + san = stableAccNameOf stableSharingAcc + mergedDepGroup :: DepGroup + mergedDepGroup = foldl mergeDepGroup emptyDepGroup subCounts + depGroup :: DepGroup + -- Adds an edge for each node in 'mergedDepGroup'. This is probably + -- overkill. It would only be necessary to add an edge to the "root" node + -- of each 'DepGroup' in 'subCounts'. + depGroup = Set.foldr (depGroupInsertEdge stableSharingAcc) + (depGroupInsertNode stableSharingAcc mergedDepGroup) + (keysSet (nodes mergedDepGroup)) -- Combine node counts that belong to the same node. -- @@ -874,19 +965,62 @@ nodeCount nc = NodeCounts [nc] -- nesting depth, but doesn't seem worthwhile as the arguments are expected to be fairly short. -- Change if profiling suggests that this function is a bottleneck. -- + (+++) :: NodeCounts -> NodeCounts -> NodeCounts -NodeCounts us +++ NodeCounts vs = NodeCounts $ merge us vs +NodeCounts us +++ NodeCounts vs = NodeCounts $ + let result = merge us vs + in if length us > 0 && length vs > 0 + then trace (printf " %s\n `merge`\n %s\n ==\n %s" + (show us) (show vs) (show result)) result + else result where + merge :: [DepGroup] -> [DepGroup] -> [DepGroup] merge [] ys = ys merge xs [] = xs - merge xs@(x@(sa1, count1) : xs') ys@(y@(sa2, count2) : ys') - | sa1 == sa2 = (sa1 `pickNoneVar` sa2, count1 + count2) : merge xs' ys' - | sa1 `notElem` map fst ys' = x : merge xs' ys - | sa2 `notElem` map fst xs' = y : merge xs ys' - | otherwise = INTERNAL_ERROR(error) "(+++)" "Precondition violated" + merge (xs:xss) yss = mergeInto xs (merge xss yss) + + mergeInto :: DepGroup -> [DepGroup] -> [DepGroup] + mergeInto xs [] = [xs] + mergeInto xs (ys:yss) + | overlap xs ys = mergeInto (mergeDepGroup xs ys) yss + | otherwise = ys:mergeInto xs yss + + -- Note: This is quadratic in complexity. HashSet does not have an 'intersection' method. + overlap :: DepGroup -> DepGroup -> Bool + overlap dg1 dg2 = overlap' (Map.keys $ nodes dg1) (Map.keys $ nodes dg2) + where + overlap' [] _ = False + overlap' (x:xs) ys = x `elem` ys || overlap' xs ys - (StableSharingAcc _ (VarSharing _)) `pickNoneVar` sa2 = sa2 - sa1 `pickNoneVar` _sa2 = sa1 +keysSet :: (Eq k, Hashable k) => HashMap k a -> HashSet k +keysSet = Set.fromList . Map.keys + +type TopoSortState = (HashSet StableAccName, [(StableSharingAcc, Int)]) + +-- +-- Returns the 'StableSharingAcc's for a 'DepGroup' along with the sharing count +-- in reverse-binding order. +-- +stableSharingAccsForDepGroup :: DepGroup -> [(StableSharingAcc, Int)] +stableSharingAccsForDepGroup dg = topoSort (Map.keys $ nodes dg) Set.empty [] + where + -- topological sort + topoSort :: [StableAccName] -> HashSet StableAccName -> [(StableSharingAcc, Int)] + -> [(StableSharingAcc, Int)] + topoSort [] _ accum = accum + topoSort (san:sans) visited accum = + let (visited', result) = visit san (visited, accum) + in topoSort sans visited' accum ++ result + visit :: StableAccName -> TopoSortState -> TopoSortState + visit san this@(visited, accum) + | Set.member san visited = this + | otherwise = + let visited' = Set.insert san visited + element = fromJust $ Map.lookup san $ nodes dg + (visited'', accum'') = case Map.lookup san (edges dg) of + Just succs -> Set.foldr visit (visited', accum) succs + Nothing -> (visited', accum) + in (visited'', element : accum'') -- Determine the scopes of all variables representing shared subterms (Phase Two) in a bottom-up -- sweep. The first argument determines whether array computations are floated out of expressions @@ -901,7 +1035,10 @@ determineScopes floatOutAcc occMap rootAcc = fst $ scopesAcc rootAcc scopesAcc (LetSharing _ _) = INTERNAL_ERROR(error) "determineScopes: scopes" "unexpected 'LetSharing'" scopesAcc sharingAcc@(VarSharing sn) - = (VarSharing sn, nodeCount (StableSharingAcc sn sharingAcc, 1)) + = trace debugMsg (VarSharing sn, newCount) + where + newCount = nodeCount (StableSharingAcc sn sharingAcc) noNodeCounts + debugMsg = printf "%s: (VarSharing) %s" (show $ StableAccName sn) (show newCount) scopesAcc (AccSharing sn pacc) = case pacc of Atag i -> reconstruct (Atag i) noNodeCounts @@ -1055,21 +1192,34 @@ determineScopes floatOutAcc occMap rootAcc = fst $ scopesAcc rootAcc -- reconstruct :: Arrays arrs => PreAcc SharingAcc arrs -> NodeCounts -> (SharingAcc arrs, NodeCounts) - reconstruct newAcc subCount - | occCount > 1 = ( VarSharing sn - , nodeCount (StableSharingAcc sn sharingAcc, 1) +++ newCount) - | otherwise = (sharingAcc, newCount) + reconstruct newAcc subCount = trace debugMsg reconstruct' where + debugMsg = printf ("%s: bindHere = %s\n subCount = %s" ++ + "\n newCount = %s" ++ + "\n newNodeCounts = %s") + (show (StableAccName sn)) + (show $ bindHere) + (show $ subCount) + (show $ newCount) + (if occCount > 1 then show + (nodeCount (StableSharingAcc sn sharingAcc) newCount) else "---") + reconstruct' + | occCount > 1 = (VarSharing sn, nodeCount (StableSharingAcc sn sharingAcc) newCount) + | otherwise = (sharingAcc, newCount) -- Determine the bindings that need to be attached to the current node... (newCount, bindHere) = filterCompleted subCount - -- ...and wrap them in 'LetSharing' constructors lets = foldl (flip (.)) id . map LetSharing $ bindHere sharingAcc = lets $ AccSharing sn newAcc - -- Extract nodes that have a complete node count (i.e., their node count is equal to the - -- number of occurences of that node in the overall expression) => nodes with a completed - -- node count should be let bound at the currently processed node. + -- Extract nodes that have a complete node count (i.e., their node count is equal + -- to the number of occurences of that node in the overall expression) => nodes + -- with a completed node count should be let bound at the currently processed + -- node. + -- + -- Nodes are extracted on a per dependency group ('DepGroup') basis. If all the + -- nodes in a dependency group have a sharing count equal to their occurrence + -- counts then they are filtered out. -- filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingAcc]) filterCompleted (NodeCounts counts) @@ -1077,14 +1227,18 @@ determineScopes floatOutAcc occMap rootAcc = fst $ scopesAcc rootAcc in (NodeCounts counts', completed) where fc [] = ([], []) - fc (sub@(sa, n):subs) + fc (sub:subs) -- current node is the binding point for the shared node 'sa' - | occCount == n = (subs', sa:bindHere) + | readyToBind sub = + let sas = map fst $ stableSharingAccsForDepGroup sub + in (subs', sas ++ bindHere) -- not a binding point | otherwise = (sub:subs', bindHere) where - occCount = lookupWithSharingAcc occMap sa (subs', bindHere) = fc subs + readyToBind :: DepGroup -> Bool + readyToBind dg = all (\(sa,n) -> lookupWithSharingAcc occMap sa == n) + (stableSharingAccsForDepGroup dg) scopesExp :: forall arrs. SharingExp arrs -> (SharingExp arrs, NodeCounts) scopesExp pacc @@ -1151,10 +1305,13 @@ determineScopes floatOutAcc occMap rootAcc = fst $ scopesAcc rootAcc -> (SharingExp b, NodeCounts) maybeFloatOutAcc c acc@(VarSharing _) accCount = (c acc, accCount) -- nothing to float out maybeFloatOutAcc c acc accCount - | floatOutAcc = (c var, nodeCount (stableAcc, 1) +++ accCount) + | floatOutAcc = trace debugMsg (c var, floatedCount) | otherwise = (c acc, accCount) where (var, stableAcc) = abstract acc id + floatedCount = nodeCount stableAcc accCount + debugMsg = printf "Floating out %s:\n nodeCount = %s" + (show stableAcc) (show $ floatedCount) abstract :: SharingAcc a -> (SharingAcc a -> SharingAcc a) -> (SharingAcc a, StableSharingAcc) diff --git a/accelerate-examples/src/Main.hs b/accelerate-examples/src/Main.hs index a92aaca90..5667ba1f6 100644 --- a/accelerate-examples/src/Main.hs +++ b/accelerate-examples/src/Main.hs @@ -72,4 +72,3 @@ main = do valid <- runVerify config tests -- unless (null valid || cfgVerify config) $ runTiming config valid - diff --git a/accelerate-examples/src/Test.hs b/accelerate-examples/src/Test.hs index 3445c4e2e..d7c378e48 100644 --- a/accelerate-examples/src/Test.hs +++ b/accelerate-examples/src/Test.hs @@ -28,6 +28,7 @@ import qualified BlockCopy import qualified Canny import qualified IntegralImage +import qualified SharingRecovery -- friends import Util @@ -108,6 +109,7 @@ data Test allTests :: Config -> IO [Test] allTests cfg = sequence' [ + -- primitive functions mkTest "map-abs" "absolute value of each element" $ Map.run "abs" n , mkTest "map-plus" "add a constant to each element" $ Map.run "plus" n @@ -144,7 +146,16 @@ allTests cfg = sequence' , mkTest "slices" "replicate (Z:.All:.All:.2)" $ SliceExamples.run3 , mkTest "slices" "replicate (Any:.2)" $ SliceExamples.run4 , mkTest "slices" "replicate (Z:.2:.2:.2)" $ SliceExamples.run5 - + -- + , mkIO "sharing-recovery" "simple" $ return (show SharingRecovery.simple) + , mkIO "sharing-recovery" "orderFail" $ return (show SharingRecovery.orderFail) + , mkIO "sharing-recovery" "testSort" $ return (show SharingRecovery.testSort) + , mkIO "sharing-recovery" "muchSharing" $ return (show $ SharingRecovery.muchSharing 20) + , mkIO "sharing-recovery" "bfsFail" $ return (show SharingRecovery.bfsFail) + , mkIO "sharing-recovery" "twoLetsSameLevel" $ return (show SharingRecovery.twoLetsSameLevel) + , mkIO "sharing-recovery" "twoLetsSameLevel2" $ return (show SharingRecovery.twoLetsSameLevel2) + , mkIO "sharing-recovery" "noLetAtTop" $ return (show SharingRecovery.noLetAtTop) + , mkIO "sharing-recovery" "noLetAtTop2" $ return (show SharingRecovery.noLetAtTop2) ] where n = cfgElements cfg @@ -158,9 +169,8 @@ allTests cfg = sequence' acc <- unsafeInterleaveIO builder return $ TestNoRef name desc acc -#ifdef ACCELERATE_IO mkIO name desc act = return $ TestIO name desc act -#endif + -- How to evaluate Accelerate programs with the chosen backend? -- @@ -193,7 +203,7 @@ verifyTest cfg test = do $ map (\(i,v) -> ">>> " ++ shows i " : " ++ show v) errs TestNoRef _ _ acc -> return $ run acc `seq` Ok - TestIO _ _ act -> act >> return Ok + TestIO _ _ act -> act >>= \v -> v `seq` return Ok -- unless quiet $ putStrLn (show result) return result diff --git a/accelerate-examples/tests/simple/SharingRecovery.hs b/accelerate-examples/tests/simple/SharingRecovery.hs new file mode 100644 index 000000000..3b5c2f31d --- /dev/null +++ b/accelerate-examples/tests/simple/SharingRecovery.hs @@ -0,0 +1,165 @@ +{-# LANGUAGE TypeOperators, ScopedTypeVariables #-} + + +-- +-- Some tests to make sure that sharing recovery is working. +-- +module SharingRecovery where + +import Prelude hiding (zip3) + +import Data.Array.Accelerate as Acc + + +mkArray :: Int -> Acc (Array DIM1 Int) +mkArray n = use $ fromList (Z:.1) [n] + +muchSharing :: Int -> Acc (Array DIM1 Int) +muchSharing 0 = (mkArray 0) +muchSharing n = Acc.map (\_ -> newArr ! (lift (Z:.(0::Int))) + + newArr ! (lift (Z:.(1::Int)))) (mkArray n) + where + newArr = muchSharing (n-1) + +idx :: Int -> Exp DIM1 +idx i = lift (Z:.i) + +bfsFail :: Acc (Array DIM1 Int) +bfsFail = Acc.map (\x -> (map2 ! (idx 1)) + (map1 ! (idx 2)) + x) arr + where + map1 :: Acc (Array DIM1 Int) + map1 = Acc.map (\y -> (map2 ! (idx 3)) + y) arr + map2 :: Acc (Array DIM1 Int) + map2 = Acc.map (\z -> z + 1) arr + arr :: Acc (Array DIM1 Int) + arr = mkArray 666 + +twoLetsSameLevel :: Acc (Array DIM1 Int) +twoLetsSameLevel = + let arr1 = mkArray 1 + in let arr2 = mkArray 2 + in Acc.map (\_ -> arr1!(idx 1) + arr1!(idx 2) + arr2!(idx 3) + arr2!(idx 4)) (mkArray 3) + +twoLetsSameLevel2 :: Acc (Array DIM1 Int) +twoLetsSameLevel2 = + let arr2 = mkArray 2 + in let arr1 = mkArray 1 + in Acc.map (\_ -> arr1!(idx 1) + arr1!(idx 2) + arr2!(idx 3) + arr2!(idx 4)) (mkArray 3) + +-- +-- These two programs test that lets can be introduced not just at the top of a AST +-- but in intermediate nodes. +-- +noLetAtTop :: Acc (Array DIM1 Int) +noLetAtTop = Acc.map (\x -> x + 1) bfsFail + +noLetAtTop2 :: Acc (Array DIM1 Int) +noLetAtTop2 = Acc.map (\x -> x + 2) $ Acc.map (\x -> x + 1) bfsFail + +-- +-- +-- +simple :: Acc (Array DIM1 (Int,Int)) +simple = Acc.map (\_ -> a ! (idx 1)) d + where + c = use $ Acc.fromList (Z :. 3) [1..] + d = Acc.map (+1) c + a = Acc.zip d c + +-------------------------------------------------------------------------------- +-- +-- sortKey is a real program that Ben Lever wrote. It has some pretty interesting +-- sharing going on. +-- +sortKey :: (Elt e) + => (Exp e -> Exp Int) -- ^mapping function to produce key array from input array + -> Acc (Vector e) + -> Acc (Vector e) +sortKey keyFun arr = foldl sortOneBit arr (Prelude.map lift ([0..31] :: [Int])) + where + sortOneBit inArr bitNum = outArr + where + keys = Acc.map keyFun inArr + + bits = Acc.map (\a -> (Acc.testBit a bitNum) ? (1, 0)) keys + bitsInv = Acc.map (\b -> (b ==* 0) ? (1, 0)) bits + + (falses, numZeroes) = Acc.scanl' (+) 0 bitsInv + trues = Acc.map (\x -> (Acc.the numZeroes) + (Acc.fst x) - (Acc.snd x)) $ + Acc.zip ixs falses + + dstIxs = Acc.map (\x -> let (b, t, f) = unlift x in (b ==* (constant (0::Int))) ? (f, t)) $ + zip3 bits trues falses + outArr = scatter dstIxs inArr inArr -- just use input as default array + --(we're writing over everything anyway) + -- + ixs = enumeratedArray (shape arr) + +-- | Copy elements from source array to destination array according to a map. For +-- example: +-- +-- default = [0, 0, 0, 0, 0, 0, 0, 0, 0] +-- map = [1, 3, 7, 2, 5, 8] +-- input = [1, 9, 6, 4, 4, 2, 5] +-- +-- output = [0, 1, 4, 9, 0, 4, 0, 6, 2] +-- +-- Note if the same index appears in the map more than once, the result is +-- undefined. The map vector cannot be larger than the input vector. +-- +scatter :: (Elt e) + => Acc (Vector Int) -- ^map + -> Acc (Vector e) -- ^default + -> Acc (Vector e) -- ^input + -> Acc (Vector e) -- ^output +scatter mapV defaultV inputV = Acc.permute (const) defaultV pF inputV + where + pF ix = lift (Z :. (mapV ! ix)) + + +-- | Create an array where each element is the value of its corresponding row-major +-- index. +-- +--enumeratedArray :: (Shape sh) => Exp sh -> Acc (Array sh Int) +--enumeratedArray sh = Acc.reshape sh +-- $ Acc.generate (index1 $ shapeSize sh) unindex1 + +enumeratedArray :: Exp DIM1 -> Acc (Array DIM1 Int) +enumeratedArray sh = Acc.generate sh unindex1 + +zip3 :: forall sh e1 e2 e3. (Shape sh, Elt e1, Elt e2, Elt e3) + => Acc (Array sh e1) + -> Acc (Array sh e2) + -> Acc (Array sh e3) + -> Acc (Array sh (e1, e2, e3)) +zip3 as bs cs = Acc.zipWith (\a bc -> let (b, c) = unlift bc :: (Exp e2, Exp e3) + in lift (a, b, c)) as $ Acc.zip bs cs + +unzip3 :: forall sh. forall e1. forall e2. forall e3. (Shape sh, Elt e1, Elt e2, Elt e3) + => Acc (Array sh (e1, e2, e3)) + -> (Acc (Array sh e1), Acc (Array sh e2), Acc (Array sh e3)) +unzip3 abcs = (as, bs, cs) + where + (bs, cs) = Acc.unzip bcs + (as, bcs) = Acc.unzip + $ Acc.map (\abc -> let (a, b, c) = unlift abc :: (Exp e1, Exp e2, Exp e3) + in lift (a, lift (b, c))) abcs + +testSort = sortKey id $ use $ fromList (Z:.10) [9,8,7,6,5,4,3,2,1,0] + +---------------------------------------------------------------------- + +-- +-- map1 has children map3 and map2. +-- map2 has child map3. +-- Back when we still used a list for the NodeCounts data structure this mean that +-- you would be merging [1,3,2] with [2,3] which violated precondition of (+++). +-- This tests that the new algorithm works just fine on this. +-- +orderFail :: Acc (Array DIM1 Int) +orderFail = Acc.map (\_ -> map1 ! (idx 1) + map2 ! (idx 1)) arr + where + map1 = Acc.map (\_ -> map3 ! (idx 1) + map2 ! (idx 2)) arr + map2 = Acc.map (\_ -> map3 ! (idx 3)) arr + map3 = Acc.map (+1) arr + arr = mkArray 42 \ No newline at end of file diff --git a/accelerate.cabal b/accelerate.cabal index 8803b887a..43baaec56 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -120,7 +120,9 @@ Library directory >= 1.0 && < 1.2, ghc-prim == 0.2.*, mtl == 2.0.*, - pretty == 1.0.* + pretty == 1.0.*, + hashable >= 1.1.1.0, + unordered-containers >= 0.1.3.0 Include-Dirs: include