Skip to content
Browse files

New code-generation (state) monad for easier skeleton contruction

  • Loading branch information...
1 parent 83df9e0 commit e2ee1fbab6c2ccc5dc8ff162a72f107975c62648 @dybber dybber committed Sep 1, 2011
View
46 Data/Array/Accelerate/OpenCL/CodeGen/Monad.hs
@@ -0,0 +1,46 @@
+module Data.Array.Accelerate.OpenCL.CodeGen.Monad where
+
+import Control.Monad.State
+
+import Language.C hiding (mkPtr)
+import Language.C.Syntax
+import Language.C.Quote.OpenCL
+
+import Data.Array.Accelerate.OpenCL.CodeGen.Data
+
+data SkelState = SkelState {
+ _definitions :: [Definition]
+ , _params :: [Param]
+ }
+
+emptySkelState = SkelState [] []
+
+type CGM = State SkelState
+
+runCGM :: CGM () -> CUTranslSkel
+runCGM st = CUTranslSkel . _definitions $ execState st emptySkelState
+
+
+-- Setters
+addDefinition :: Definition -> CGM ()
+addDefinition def =
+ modify $ \s -> s {_definitions = def : (_definitions s)}
+
+addParam :: Param -> CGM ()
+addParam param =
+ modify $ \s -> s {_params = param : (_params s)}
+
+addDefinitions :: [Definition] -> CGM ()
+addDefinitions = mapM_ addDefinition
+
+addParams :: [Param] -> CGM ()
+addParams = mapM_ addParam
+
+-- Getters
+getDefinitions :: CGM [Definition]
+getDefinitions = gets _definitions
+
+getParams :: CGM [Param]
+getParams = gets _params
+
+
View
178 Data/Array/Accelerate/OpenCL/CodeGen/Skeleton.hs
@@ -17,7 +17,7 @@ module Data.Array.Accelerate.OpenCL.CodeGen.Skeleton
mkMap, mkZipWith,
-- mkStencil, mkStencil2,
-- mkScanl, mkScanr, mkScanl', mkScanr', mkScanl1, mkScanr1,
--- mkPermute, mkBackpermute, mkIndex, mkReplicate
+ mkPermute --, mkBackpermute, mkIndex, mkReplicate
)
where
@@ -34,6 +34,7 @@ import Data.Symbol
import Data.Array.Accelerate.OpenCL.CodeGen.Data
import Data.Array.Accelerate.OpenCL.CodeGen.Util
import Data.Array.Accelerate.OpenCL.CodeGen.Tuple
+import Data.Array.Accelerate.OpenCL.CodeGen.Monad
--import Data.Array.Accelerate.CUDA.CodeGen.Stencil
@@ -115,67 +116,61 @@ import Data.Array.Accelerate.OpenCL.CodeGen.Tuple
-- Map
-- ---
mkMap :: [C.Type] -> [C.Type] -> C.Exp -> CUTranslSkel
-mkMap tyOut tyIn_A apply = CUTranslSkel $ outputdefs ++ inputdefs ++ [apply'] ++ skel
- where
- (outputdefs, out_params, Set callSet) = mkOutputTuple tyOut
- (inputdefs, in_params, Get callGet) = mkInputTuple "A" tyIn_A
-
- skel :: [Definition]
- skel = [cunit|
- __kernel void map (const int shape, $params:(out_params ++ in_params)) {
- int idx;
- const int gridSize = get_global_size(0);
-
- for(idx = get_global_id(0); idx < shape; idx += gridSize) {
- $ty:(typename "TyInA") val = $exp:(callGet "idx") ;
- $ty:outType new = apply(val) ;
- $exp:(callSet "idx" "new") ;
- }
- }
- |]
-
- apply' :: Definition
- apply' = mkApply 1 apply
-
+mkMap tyOut tyIn_A apply = runCGM $ do
+ Set set <- mkOutputTuple tyOut
+ Get get <- mkInputTuple "A" tyIn_A
+ mkApply 1 apply
+
+ ps <- getParams
+ addDefinitions
+ [cunit|
+ __kernel void map (const int shape, $params:ps) {
+ int idx;
+ const int gridSize = get_global_size(0);
+
+ for(idx = get_global_id(0); idx < shape; idx += gridSize) {
+ $ty:(typename "TyInA") val = $exp:(get "idx") ;
+ $ty:outType new = apply(val) ;
+ $exp:(set "idx" "new") ;
+ }
+ }
+ |]
mkZipWith :: ([C.Type], Int)
-> ([C.Type], Int)
->([C.Type], Int) -> C.Exp -> CUTranslSkel
mkZipWith (tyOut,dimOut) (tyInB, dimInB) (tyInA, dimInA) apply =
- CUTranslSkel $ outputdefs ++ inputdefsA ++ inputdefsB ++
- [sh_out_def, sh_inB_def, sh_inA_def, apply'] ++ skel
- where
- (outputdefs, out_params, Set callSet) = mkOutputTuple tyOut
- (inputdefsA, in_paramsA, Get callGetA) = mkInputTuple "A" tyInA
- (inputdefsB, in_paramsB, Get callGetB) = mkInputTuple "B" tyInB
-
- (sh_out_type, sh_out_def) = mkShape "DimOut" dimOut
- (sh_inB_type, sh_inB_def) = mkShape "DimInB" dimInB
- (sh_inA_type, sh_inA_def) = mkShape "DimInA" dimInA
-
- apply' :: Definition
- apply' = mkApply 2 apply
-
- skel :: [Definition]
- skel = [cunit|
- __kernel void zipWith (const $ty:sh_out_type shOut,
- const $ty:sh_inB_type shInB,
- const $ty:sh_inA_type shInA,
- $params:(out_params ++ in_paramsB ++ in_paramsA)) {
- const $ty:ixType shapeSize = $id:(size dimOut)(shOut);
- const $ty:ixType gridSize = get_global_size(0);
-
- for ($ty:ixType ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
- $ty:ixType iA = $id:(toIndex dimInB)(shInB, $id:(fromIndex dimInB)(shOut, ix));
- $ty:ixType iB = $id:(toIndex dimInA)(shInA, $id:(fromIndex dimInA)(shOut, ix));
-
- $ty:(typename "TyInB") valB = $exp:(callGetB "iB") ;
- $ty:(typename "TyInA") valA = $exp:(callGetA "iA") ;
- $ty:outType new = apply(valB, valA) ;
- $exp:(callSet "ix" "new") ;
- }
- }
- |]
+ runCGM $ do
+ Set set <- mkOutputTuple tyOut
+ Get getA <- mkInputTuple "A" tyInA
+ Get getB <- mkInputTuple "B" tyInB
+ mkApply 2 apply
+
+ shape_out <- mkShape "DimOut" dimOut
+ shape_inB <- mkShape "DimInB" dimInB
+ shape_inA <- mkShape "DimInA" dimInA
+
+ ps <- getParams
+ addDefinitions
+ [cunit|
+ __kernel void zipWith (const $ty:shape_out shOut,
+ const $ty:shape_inB shInB,
+ const $ty:shape_inA shInA,
+ $params:ps) {
+ const $ty:ix shapeSize = $id:(size dimOut)(shOut);
+ const $ty:ix gridSize = get_global_size(0);
+
+ for ($ty:ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
+ $ty:ix iA = $id:(toIndex dimInB)(shInB, $id:(fromIndex dimInB)(shOut, ix));
+ $ty:ix iB = $id:(toIndex dimInA)(shInA, $id:(fromIndex dimInA)(shOut, ix));
+
+ $ty:(typename "TyInB") valB = $exp:(getB "iB") ;
+ $ty:(typename "TyInA") valA = $exp:(getA "iA") ;
+ $ty:outType new = apply(valB, valA) ;
+ $exp:(set "ix" "new") ;
+ }
+ }
+ |]
-- -- Stencil
@@ -273,17 +268,41 @@ mkZipWith (tyOut,dimOut) (tyInB, dimInB) (tyInA, dimInA) apply =
-- -- Permutation
-- -- -----------
--- mkPermute :: [CType] -> Int -> Int -> [CExpr] -> [CExpr] -> CUTranslSkel
--- mkPermute ty dimOut dimIn0 combinefn indexfn = CUTranslSkel code [] skel
--- where
--- skel = "permute.inl"
--- code = CTranslUnit
--- ( mkTupleTypeAsc 2 ty ++
--- [ mkDim "DimOut" dimOut
--- , mkDim "DimIn0" dimIn0
--- , mkProject Forward indexfn
--- , mkApply 2 combinefn ])
--- (mkNodeInfo (initPos skel) (Name 0))
+mkPermute :: [C.Type] -> Int -> Int -> C.Exp -> C.Exp -> CUTranslSkel
+mkPermute ty dimOut dimInA combinefn indexfn = runCGM $ do
+ (Set set : Get get : _) <- mkTupleTypeAsc 2 ty
+ shape_out <- mkShape "DimOut" dimOut
+ shape_inA <- mkShape "DimInA" dimInA
+
+ mkApply 2 combinefn
+ mkProject Forward indexfn
+
+ ps <- getParams
+ addDefinitions
+ [cunit|
+ __kernel void permute (const $ty:shape_out shOut,
+ const $ty:shape_inA shInA,
+ $params:ps) {
+ const $ty:ix shapeSize = $id:(size dimInA)(shIn0);
+ const $ty:ix gridSize = get_global_size(0);
+
+ for ($ty:ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
+ $ty:shape_inA src = $id:(fromIndex dimInA)(shIn0, ix);
+ $ty:shape_out dst = project(src);
+
+ if (!ignore(dst)) {
+ $ty:ix j = $id:(toIndex dimOut)(shOut, dst);
+
+ $ty:(typename "TyOut") valB = $exp:(get "j") ;
+ $ty:(typename "TyInA") valA = $exp:(get "ix") ;
+ $ty:outType new = apply(valB, valA) ;
+ $exp:(set "j" "new") ;
+
+ //set(d_out, j, apply(get0(d_in0, ix), get0(d_out, j)));
+ }
+ }
+ }
+ |]
-- mkBackpermute :: [CType] -> Int -> Int -> [CExpr] -> CUTranslSkel
-- mkBackpermute ty dimOut dimIn0 indexFn = CUTranslSkel code [] skel
@@ -316,6 +335,29 @@ mkZipWith (tyOut,dimOut) (tyInB, dimInB) (tyInA, dimInA) apply =
-- mkReplicate :: [CType] -> Int -> Int -> [CExpr] -> CUTranslSkel
-- mkReplicate ty dimSl dimOut slix = CUTranslSkel code [] skel
-- where
+-- (outputdefs, out_params, Set callSet) = mkOutputTuple ty
+-- (slice_type, slice_def) = mkShape "Slice" dimSl
+-- (slice_dim_type, slice_dim_def) = mkShape "SliceDim" dimOut
+
+-- skel :: [Definition]
+-- skel = [cunit|
+-- __kernel void replicate (
+-- const $ty:slice_type slice,
+-- const $ty:slice_dim_type sliceDim,
+-- __global TyOut *d_out,
+-- __global const TyIn0 *d_in0) {
+
+-- const Ix shapeSize = sizeSliceDim(sliceDim);
+-- const Ix gridSize = get_global_size(0);
+
+-- for (Ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
+-- SliceDim dst = fromIndexSliceDim(sliceDim, ix);
+-- Slice src = sliceIndex(dst);
+-- set(d_out, ix, get0(d_in0, toIndexSlice(slice, src)));
+-- }
+-- }
+-- |]
+
-- skel = "replicate.inl"
-- code = CTranslUnit
-- ( mkTupleTypeAsc 1 ty ++
View
82 Data/Array/Accelerate/OpenCL/CodeGen/Tuple.hs
@@ -11,8 +11,9 @@
module Data.Array.Accelerate.OpenCL.CodeGen.Tuple
(
- mkInputTuple, mkOutputTuple, Accessor (..)
- -- mkTupleType, mkTupleTypeAsc, mkTuplePartition
+ mkInputTuple, mkOutputTuple, Accessor (..),
+ mkTupleTypeAsc
+ -- mkTupleType, mkTuplePartition
)
where
@@ -25,37 +26,51 @@ import qualified Language.C.Syntax
import qualified Data.Loc
import qualified Data.Symbol
+import Data.Array.Accelerate.OpenCL.CodeGen.Monad
+import Control.Monad
+
import Data.Array.Accelerate.OpenCL.CodeGen.Util
data Accessor = Get (String -> Exp)
| Set (String -> String -> Exp)
-mkInputTuple :: String -> [Type]-> ([Definition], [Param], Accessor)
+mkInputTuple :: String -> [Type]-> CGM Accessor
mkInputTuple subscript types = mkTupleType (Just subscript) types
-mkOutputTuple :: [Type]-> ([Definition], [Param], Accessor)
+mkOutputTuple :: [Type]-> CGM Accessor
mkOutputTuple types = mkTupleType Nothing types
-mkTupleType :: Maybe String -> [Type] -> ([Definition], [Param], Accessor)
-mkTupleType subscript types = (typedefs ++ struct ++ [accessor], params, accessorCall)
+mkTupleType :: Maybe String -> [Type] -> CGM Accessor
+mkTupleType subscript types = do
+ let n = length types
+ tuple_name = maybe "TyOut" ("TyIn" ++) subscript
+ volatile = isNothing subscript
+ tynames
+ | n > 1 = take n [tuple_name ++ "_" ++ show i | i <- [0..]] -- TyInA_0, TyInA_1, ...
+ | otherwise = [tuple_name]
+
+ addDefinitions $ zipWith (mkTypedef volatile) tynames types
+ accessorCall <- mkParameterList subscript n tynames
+ (maybe mkSet mkGet subscript) n tynames
+ when (n > 1) $ addDefinition (mkStruct tuple_name volatile types)
+ return accessorCall
+
+mkTupleTypeAsc :: Int -> [Type] -> CGM [Accessor]
+mkTupleTypeAsc n types = do
+ accessorOut <- mkOutputTuple types
+ accessorsIn <- mkInputTuples (n-1)
+ return $ accessorOut : accessorsIn
where
- n = length types
- tuple_name = maybe "TyOut" ("TyIn" ++) subscript
- volatile = isNothing subscript
- tynames
- | n > 1 = take n [tuple_name ++ "_" ++ show i | i <- [0..]] -- TyInA_0, TyInA_1, ...
- | otherwise = [tuple_name]
-
- -- typedef float TyInA_0; typedef float TyInA_1; ...
- typedefs = zipWith (mkTypedef volatile) tynames types
- (params, accessorCall) = mkParameterList subscript n tynames
- accessor = (maybe mkSet mkGet subscript) n tynames params
- struct
- | n > 1 = [mkStruct tuple_name volatile types]
- | otherwise = []
-
-mkParameterList :: Maybe String -> Int -> [String] -> ([Param], Accessor)
-mkParameterList subscript n tynames = (params $ zip types' param_names, accessorCall)
+ mkInputTuples 0 = return []
+ mkInputTuples n = do
+ as <- mkInputTuples (n-1)
+ a <- mkInputTuple (show $ n-1) types
+ return $ a : as
+
+mkParameterList :: Maybe String -> Int -> [String] -> CGM Accessor
+mkParameterList subscript n tynames = do
+ addParams $ params (zip types' param_names)
+ return accessorCall
where
param_prefix = maybe "out" ("in" ++) subscript
param_names
@@ -69,8 +84,9 @@ mkParameterList subscript n tynames = (params $ zip types' param_names, accessor
Nothing -> Set $ \idx val -> [cexp|set($id:idx, $id:val, $args:args)|]
Just x -> Get $ \idx -> [cexp|$id:("get" ++ x)($id:idx, $args:args)|]
-mkGet :: String -> Int -> [String] -> [Param] -> (Definition)
-mkGet prj n tynames params =
+mkGet :: String -> Int -> [String] -> CGM ()
+mkGet prj n tynames = do
+ params <- getParams
let name = "get" ++ prj
param_name = "in" ++ prj
returnType = typename $ "TyIn" ++ prj
@@ -79,23 +95,27 @@ mkGet prj n tynames params =
assignments
| n > 1 = zipWith assign [0..] tynames
| otherwise = [ [cstm|val = $id:param_name [idx];|] ]
- in [cedecl|
- inline $ty:returnType $id:name(const $ty:ixType idx, $params:params) {
+
+ addDefinition
+ [cedecl|
+ inline $ty:returnType $id:name(const $ty:ix idx, $params:params) {
$ty:returnType val;
$stms:assignments
return val;
}
|]
-mkSet :: Int -> [String] -> [Param] -> Definition
-mkSet n tynames params =
+mkSet :: Int -> [String] -> CGM ()
+mkSet n tynames = do
+ params <- getParams
let assign i name = let field = 'a' : show i
in [cstm|$id:name [idx] = val.$id:field;|]
assignments
| n > 1 = zipWith assign [0..] tynames
| otherwise = [ [cstm|out[idx] = val;|] ]
- in [cedecl|
- inline void set(const $ty:ixType idx, const $ty:outType val, $params:params) {
+ addDefinition
+ [cedecl|
+ inline void set(const $ty:ix idx, const $ty:outType val, $params:params) {
$stms:assignments
}
|]
View
35 Data/Array/Accelerate/OpenCL/CodeGen/Util.hs
@@ -9,11 +9,13 @@ import Language.C.Quote.OpenCL
import Data.Loc
import Data.Symbol
+import Data.Array.Accelerate.OpenCL.CodeGen.Monad
+
data Direction = Forward | Backward
-- Types
-ixType :: Type
-ixType = typename "Ix"
+ix :: Type
+ix = typename "Ix"
outType :: Type
outType = typename "TyOut"
@@ -26,14 +28,19 @@ outType = typename "TyOut"
mkIdentity :: Exp -> Definition
mkIdentity = mkDeviceFun "identity" (typename "TyOut") []
-mkApply :: Int -> Exp -> Definition
-mkApply argc
- = mkDeviceFun "apply" outType
- $ params $ map (\c -> (typename ("TyIn"++ [c]), 'x' : [c])) $ reverse $ take argc ['A'..]
+mkApply :: Int -> Exp -> CGM ()
+mkApply argc exp
+ = addDefinition $
+ (mkDeviceFun "apply" outType
+ $ params $ map (\c -> (typename ("TyIn"++ [c]), 'x' : [c])) $ reverse $ take argc ['A'..]) exp
-mkProject :: Direction -> Exp -> Definition
-mkProject Forward = mkDeviceFun "project" (typename "DimOut") $ params [(typename "DimIn0","x0")]
-mkProject Backward = mkDeviceFun "project" (typename "DimIn0") $ params [(typename "DimOut","x0")]
+mkProject :: Direction -> Exp -> CGM ()
+mkProject Forward exp =
+ addDefinition $
+ (mkDeviceFun "project" (typename "DimOut") $ params [(typename "DimIn0","x0")]) exp
+mkProject Backward exp =
+ addDefinition $
+ (mkDeviceFun "project" (typename "DimIn0") $ params [(typename "DimOut","x0")]) exp
mkSliceIndex :: Exp -> Definition
mkSliceIndex =
@@ -100,12 +107,14 @@ mkTypedef volatile tyname typ | volatile = let typ' = mkVolatile typ
in [cedecl|typedef $ty:typ' $id:tyname;|]
| otherwise = [cedecl|typedef $ty:typ $id:tyname;|]
-mkShape :: String -> Int -> (Type, Definition)
-mkShape name dim = (typename name, typedef)
+mkShape :: String -> Int -> CGM Type
+mkShape name dim = do
+ addDefinition typedef
+ return $ typename name
where
typedef | dim == 0 = [cedecl| typedef void* $id:name; |]
- | dim == 1 = [cedecl| typedef $ty:ixType $id:name; |]
- | otherwise = mkStruct name False (replicate dim ixType)
+ | dim == 1 = [cedecl| typedef $ty:ix $id:name; |]
+ | otherwise = mkStruct name False (replicate dim ix)
toIndex :: Int -> String
toIndex dim = "toIndexDIM" ++ show dim

0 comments on commit e2ee1fb

Please sign in to comment.
Something went wrong with that request. Please try again.