Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

More work on folding over arrays of tuples

  • Loading branch information...
commit c9732dc7b07529eb894b57cf5a66145ffc507bde 1 parent 12671d7
@dybber dybber authored
View
114 Data/Array/Accelerate/OpenCL/CodeGen/Reduce.hs
@@ -37,10 +37,10 @@ import Data.Array.Accelerate.OpenCL.CodeGen.Monad
-- Exported functions
-- ------------------
mkFold :: ([C.Type],Int) -> C.Exp -> C.Exp -> CUTranslSkel
-mkFold ty identity apply = makeFold False ty (Just identity) apply
+mkFold (ty, dimIn) identity apply = makeFold True (ty, dimIn) (Just identity) apply
mkFold1 :: ([C.Type],Int) -> C.Exp -> CUTranslSkel
-mkFold1 ty apply = makeFold True ty Nothing apply
+mkFold1 (ty, dimIn) apply = makeFold False (ty, dimIn) Nothing apply
@@ -48,34 +48,42 @@ mkFold1 ty apply = makeFold True ty Nothing apply
-- ---------
makeFold :: Bool -> ([C.Type],Int) -> Maybe C.Exp -> C.Exp -> CUTranslSkel
-makeFold inclusive (ty, dim) identity apply = runCGM $ do
+makeFold inclusive (ty, dimIn) identity apply = runCGM $ do
(d_out, d_inA : _) <- mkTupleTypeAsc 1 ty
+ (d_local,local_params) <- mkParameterList Local (Just "local") n tynames
fromMaybe (return ()) (mkIdentity <$> identity)
mkApplyAsc 2 apply
- mkDim "DimInA" dim
- mkDim "DimOut" (dim-1)
- let mkSkel | dim == 1 = mkFoldAllSkel
- | otherwise = mkFoldSkel
- mkSkel d_out d_inA inclusive
-
-mkFoldSkel :: Arguments -> Arguments -> Bool -> CGM ()
+ mkDim "DimInA" dimIn
+ mkDim "DimOut" (dimIn-1)
+ let mkSkel | dimIn == 1 = mkFoldAllSkel
+ | otherwise = mkFoldSkel
+ mkSkel d_out d_inA (d_local, local_params) inclusive
+ where
+ n = length ty
+ tynames
+ | n > 1 = take n ["TyOut" ++ "_" ++ show i | i <- [0..]]
+ | otherwise = ["TyOut"]
+
+
+mkFoldSkel :: Arguments -> Arguments -> (Arguments, [C.Param]) -> Bool -> CGM ()
mkFoldSkel = error "folds for higher dimensions are not yet supported in the OpenCL backend for Accelerate"
-mkFoldAllSkel :: Arguments -> Arguments -> Bool -> CGM ()
-mkFoldAllSkel d_out d_inA inclusive = do
+mkFoldAllSkel :: Arguments -> Arguments -> (Arguments, [C.Param]) -> Bool -> CGM ()
+mkFoldAllSkel d_out d_inA (d_local, local_params) inclusive = do
ps <- getParams
mkHandleSeed d_out inclusive
-
- let include = "#include <reduce.cl>"
- addDefinitions [cunit| $esc:include |]
+
+ mkWarpReduce local_params d_local
+ mkBlockReduce local_params d_local
addDefinitions
[cunit|
__kernel void fold (const typename Ix shape,
- $params:ps) {
+ $params:ps,
+ $params:local_params) {
- volatile __local typename TyOut s_data[100];
+ //volatile __local typename TyOut s_data[100];
//__global ArrOutTy *s_data = partition(s_ptr, get_local_size(0));
/*
@@ -94,8 +102,7 @@ mkFoldAllSkel d_out d_inA inclusive = do
*
* The loop stride of `gridSize' is used to maintain coalescing.
*/
- if (i < shape)
- {
+ if (i < shape) {
sum = getA(i, $args:d_inA);
for (i += gridSize; i < shape; i += gridSize)
sum = apply(sum, getA(i, $args:d_inA));
@@ -105,29 +112,83 @@ mkFoldAllSkel d_out d_inA inclusive = do
* Each thread puts its local sum into shared memory, then threads
* cooperatively reduce the shared array to a single value.
*/
- set_local(tid, sum, s_data);
+ set_local(tid, sum, $args:d_local);
barrier(CLK_LOCAL_MEM_FENCE);
- sum = reduce_block_n(s_data, sum, min(shape, blockSize));
+ sum = reduce_block_n(sum, min(shape, blockSize), $args:d_local);
/*
* Write the results of this block back to global memory. If we are the last
* phase of a recursive multi-block reduction, include the seed element.
*/
- if (tid == 0)
- {
- handleSeed(sum, $args:d_out, $args:d_inA);
+ if (tid == 0) {
+ handleSeed(shape, sum, $args:d_out, $args:d_inA);
}
}
|]
+-- | Cooperatively reduce a single warp's segment of an array to a single value
+mkWarpReduce :: [C.Param] -> Arguments -> CGM ()
+mkWarpReduce ps args = do
+ addDefinitions $
+ [cunit|
+ inline typename TyOut reduce_warp_n (typename TyOut sum,
+ typename Ix n,
+ $params:ps) {
+ int warpSize = 32;
+ const typename Ix tid = get_local_id(0);
+ const typename Ix lane = get_local_id(0) & (warpSize - 1);
+
+ if (n > 16 && lane + 16 < n) { sum = apply(sum, getA_local(tid+16, $args:args)); set_local(tid, sum, $args:args); }
+ if (n > 8 && lane + 8 < n) { sum = apply(sum, getA_local(tid+ 8, $args:args)); set_local(tid, sum, $args:args); }
+ if (n > 4 && lane + 4 < n) { sum = apply(sum, getA_local(tid+ 4, $args:args)); set_local(tid, sum, $args:args); }
+ if (n > 2 && lane + 2 < n) { sum = apply(sum, getA_local(tid+ 2, $args:args)); set_local(tid, sum, $args:args); }
+ if (n > 1 && lane + 1 < n) { sum = apply(sum, getA_local(tid+ 1, $args:args)); }
+ return sum;
+ }
+ |]
+
+-- | Block reduction to a single value
+mkBlockReduce :: [C.Param] -> Arguments -> CGM ()
+mkBlockReduce ps args = do
+ addDefinitions $
+ [cunit|
+ inline typename TyOut reduce_block_n(typename TyOut sum,
+ typename Ix n,
+ $params:ps) {
+ const typename Ix tid = get_local_id(0);
+ if (n > 512) { if (tid < 512 && tid + 512 < n) { sum = apply(sum, getA_local(tid+512, $args:args)); set_local(tid, sum, $args:args); } }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ if (n > 256) { if (tid < 256 && tid + 256 < n) { sum = apply(sum, getA_local(tid+256, $args:args)); set_local(tid, sum, $args:args); } }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ if (n > 128) { if (tid < 128 && tid + 128 < n) { sum = apply(sum, getA_local(tid+128, $args:args)); set_local(tid, sum, $args:args); } }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ if (n > 64) { if (tid < 64 && tid + 64 < n) { sum = apply(sum, getA_local(tid+ 64, $args:args)); set_local(tid, sum, $args:args); } }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ if (n > 32) { if (tid < 32 && tid + 32 < n) { sum = apply(sum, getA_local(tid+ 32, $args:args)); set_local(tid, sum, $args:args); }}
+ barrier(CLK_LOCAL_MEM_FENCE);
+ if (n > 16) { if (tid < 16 && tid + 16 < n) { sum = apply(sum, getA_local(tid+ 16, $args:args)); set_local(tid, sum, $args:args); }}
+ barrier(CLK_LOCAL_MEM_FENCE);
+ if (n > 8) { if (tid < 8 && tid + 8 < n) { sum = apply(sum, getA_local(tid+ 8, $args:args)); set_local(tid, sum, $args:args); }}
+ barrier(CLK_LOCAL_MEM_FENCE);
+ if (n > 4) { if (tid < 4 && tid + 4 < n) { sum = apply(sum, getA_local(tid+ 4, $args:args)); set_local(tid, sum, $args:args); }}
+ barrier(CLK_LOCAL_MEM_FENCE);
+ if (n > 2) { if (tid < 2 && tid + 2 < n) { sum = apply(sum, getA_local(tid+ 2, $args:args)); set_local(tid, sum, $args:args); }}
+ barrier(CLK_LOCAL_MEM_FENCE);
+ if (n > 1) { if (tid == 0 && tid + 1 < n) { sum = apply(sum, getA_local(tid+ 1, $args:args)); }}
+
+ return sum;
+ }
+ |]
+
mkHandleSeed :: Arguments -> Bool -> CGM ()
mkHandleSeed d_out False = do
ps <- getParams
addDefinitions
[cunit|
- inline void handleSeed(typename TyOut sum,
+ inline void handleSeed(const typename Ix shape,
+ typename TyOut sum,
$params:ps)
{
typename Ix blockIdx = (get_global_id(0)-get_local_id(0)) / get_local_size(0);
@@ -138,7 +199,8 @@ mkHandleSeed d_out True = do
ps <- getParams
addDefinitions
[cunit|
- inline void handleSeed(typename TyOut sum,
+ inline void handleSeed(const typename Ix shape,
+ typename TyOut sum,
$params:ps)
{
typename Ix blockIdx = (get_global_id(0)-get_local_id(0)) / get_local_size(0);
View
30 Data/Array/Accelerate/OpenCL/CodeGen/Tuple.hs
@@ -12,7 +12,8 @@
module Data.Array.Accelerate.OpenCL.CodeGen.Tuple
(
mkInputTuple, mkOutputTuple, --Accessor (..),
- mkTupleTypeAsc, Arguments
+ mkTupleTypeAsc, Arguments,
+ mkParameterList
-- mkTupleType, mkTuplePartition
)
where
@@ -54,7 +55,7 @@ mkTupleType subscript types = do
| otherwise = [tuple_name]
addDefinitions $ zipWith (mkTypedef volatile) tynames types
- when (n > 1) $ addDefinition (mkStruct tuple_name volatile types)
+ when (n > 1) $ addDefinition (mkStruct tuple_name volatile $ map typename tynames)
(args,ps) <- mkParameterList Global subscript n tynames
(_,psLocal) <- mkParameterList Local subscript n tynames
(maybe mkSet mkGet subscript) n ps Global
@@ -62,11 +63,30 @@ mkTupleType subscript types = do
addParams ps
return args
+mkInputTypedef :: String -> Int -> CGM Arguments
+mkInputTypedef subscript n = do
+ let tuple_name = "TyIn" ++ subscript
+ tynames_in
+ | n > 1 = take n [tuple_name ++ "_" ++ show i | i <- [0..]] -- TyInA_0, TyInA_1, ...
+ | otherwise = [tuple_name]
+ tynames_out
+ | n > 1 = take n ["TyOut" ++ "_" ++ show i | i <- [0..]] -- TyInA_0, TyInA_1, ...
+ | otherwise = ["TyOut"]
+
+ addDefinitions $ zipWith (mkTypedef True) tynames_in $ map typename tynames_out
+ when (n > 1) $ addDefinition $ mkTypedef False tuple_name (typename "TyOut")
+ (args,ps) <- mkParameterList Global (Just subscript) n tynames_in
+ (_,psLocal) <- mkParameterList Local (Just subscript) n tynames_in
+ mkGet subscript n ps Global
+ mkGet subscript n psLocal Local
+ addParams ps
+ return args
+
mkTupleTypeAsc :: Int -> [Type] -> CGM (Arguments, [Arguments])
-mkTupleTypeAsc n typ = do
+mkTupleTypeAsc cargs typ = do
argsOut <- mkOutputTuple typ
- let names = [ [chr $ ord 'A' + i] | i <- [0..n-1]]
- argsIn <- mapM (flip mkInputTuple typ) names
+ let names = [ [chr $ ord 'A' + i] | i <- [0..cargs-1]]
+ argsIn <- mapM (flip mkInputTypedef $ length typ) names
return $ (argsOut, argsIn)
-- mkLocalAccessors :: Int -> [Type] -> CGM ()
View
7 Data/Array/Accelerate/OpenCL/CodeGen/Util.hs
@@ -21,7 +21,6 @@ outType :: Type
outType = typename "TyOut"
-
-- Common device functions
-- -----------------------
@@ -107,7 +106,11 @@ mkPtr (Type (DeclSpec storage quals typ l0) _ l1) =
mkPtr _ = error "Not a DeclSpec"
data StorageQual = Global | Local
- deriving (Eq, Show)
+ deriving (Eq)
+
+instance Show StorageQual where
+ show Global = "__global"
+ show Local = "__local"
changeStorage :: StorageQual -> Type -> Type
changeStorage stor (Type (DeclSpec storage quals typ l0) _ l1) =
View
5 Data/Array/Accelerate/OpenCL/Execute.hs
@@ -45,7 +45,7 @@ import Control.Monad.Trans
import System.IO.Unsafe
import Foreign.Storable
-import Foreign.Ptr (Ptr)
+import Foreign.Ptr (Ptr, nullPtr)
import qualified Foreign.OpenCL.Bindings as OpenCL
@@ -373,7 +373,7 @@ foldOp c kernel bindings acc aenv (Array sh0 in0)
| dim sh0 == 1 = do
cfg@(_,_,(_,g,_)) <- configure kernel acc (size sh0)
res@(Array _ out) <- newArray (bool c 1 (g > 1)) (toElt (fst sh0,g)) :: CIO (Array (dim:.Int) e)
- dispatch cfg bindings aenv ((((),size sh0),out),in0)
+ dispatch cfg bindings aenv (((((),size sh0),out),in0), OpenCL.LocalArrayArg (undefined :: Int) (size sh0))
freeArray in0
if g > 1 then foldOp c kernel bindings acc aenv res
else return (Array (fst sh0) out)
@@ -694,7 +694,6 @@ primMarshalable((Ptr a))
instance Marshalable (OpenCL.MemObject a) where
marshal x = return [OpenCL.MObjArg x]
-
instance Marshalable OpenCL.KernelArg where
marshal x = return [x]
Please sign in to comment.
Something went wrong with that request. Please try again.