Permalink
Browse files

Implementation of mkBackpermute and mkReplicate (yet to be tested)

  • Loading branch information...
1 parent 9bef826 commit 2723e86c9bc57972c793073b4bb857d75d87ee6e @dybber dybber committed Sep 1, 2011
Showing with 61 additions and 22 deletions.
  1. +57 −19 Data/Array/Accelerate/OpenCL/CodeGen/Skeleton.hs
  2. +4 −3 Data/Array/Accelerate/OpenCL/CodeGen/Util.hs
@@ -17,7 +17,8 @@ module Data.Array.Accelerate.OpenCL.CodeGen.Skeleton
mkMap, mkZipWith,
-- mkStencil, mkStencil2,
-- mkScanl, mkScanr, mkScanl', mkScanr', mkScanl1, mkScanr1,
- mkPermute --, mkBackpermute, mkIndex, mkReplicate
+ mkPermute, mkBackpermute, mkReplicate
+--, mkIndex
)
where
@@ -298,16 +299,33 @@ mkPermute ty dimOut dimInA combinefn indexfn = runCGM $ do
}
|]
--- mkBackpermute :: [CType] -> Int -> Int -> [CExpr] -> CUTranslSkel
--- mkBackpermute ty dimOut dimIn0 indexFn = CUTranslSkel code [] skel
--- where
--- skel = "backpermute.inl"
--- code = CTranslUnit
--- ( mkTupleTypeAsc 1 ty ++
--- [ mkDim "DimOut" dimOut
--- , mkDim "DimIn0" dimIn0
--- , mkProject Backward indexFn ])
--- (mkNodeInfo (initPos skel) (Name 0))
+
+mkBackpermute :: [C.Type] -> Int -> Int -> C.Exp -> CUTranslSkel
+mkBackpermute ty dimOut dimInA indexFn = runCGM $ do
+ (d_out, d_inA : _) <- mkTupleTypeAsc 1 ty
+ shape_out <- mkShape "DimOut" dimOut
+ shape_inA <- mkShape "DimInA" dimInA
+
+ mkProject Backward indexFn
+
+ ps <- getParams
+ addDefinitions
+ [cunit|
+ __kernel void permute (const $ty:shape_out shOut,
+ const $ty:shape_inA shInA,
+ $params:ps) {
+ const $ty:ix shapeSize = $id:(size dimInA)(shInA);
+ const $ty:ix gridSize = get_global_size(0);
+
+ for ($ty:ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
+ $ty:shape_out src = $id:(fromIndex dimOut)(shOut, ix);
+ $ty:shape_inA src = project(dst);
+
+ $ty:ix j = $id:(toIndex dimInA)(shInA, dst);
+ set("ix", getA("j", $args:d_inA), $args:d_out) ;
+ }
+ }
+ |]
-- -- Multidimensional Index and Replicate
@@ -326,7 +344,7 @@ mkPermute ty dimOut dimInA combinefn indexfn = runCGM $ do
-- (mkNodeInfo (initPos skel) (Name 0))
--- mkReplicate :: [CType] -> Int -> Int -> [CExpr] -> CUTranslSkel
+-- mkReplicate :: [C.Type] -> Int -> Int -> [C.Exp] -> CUTranslSkel
-- mkReplicate ty dimSl dimOut slix = CUTranslSkel code [] skel
-- where
-- (outputdefs, out_params, Set callSet) = mkOutputTuple ty
@@ -352,11 +370,31 @@ mkPermute ty dimOut dimInA combinefn indexfn = runCGM $ do
-- }
-- |]
--- skel = "replicate.inl"
--- code = CTranslUnit
--- ( mkTupleTypeAsc 1 ty ++
--- [ mkDim "Slice" dimSl
--- , mkDim "SliceDim" dimOut
--- , mkSliceReplicate slix ])
--- (mkNodeInfo (initPos skel) (Name 0))
+
+mkReplicate :: [C.Type] -> Int -> Int -> C.Exp -> CUTranslSkel
+mkReplicate ty dimSl dimOut slix = runCGM $ do
+ (d_out, d_inA : _) <- mkTupleTypeAsc 1 ty
+ slice <- mkShape "Slice" dimSl
+ slice_dim <- mkShape "SliceDim" dimOut
+
+ mkSliceReplicate slix
+
+ ps <- getParams
+ addDefinitions
+ [cunit|
+ __kernel void permute (const $ty:slice shOut,
+ const $ty:slice_dim shInA,
+ $params:ps) {
+ const $ty:ix shapeSize = $id:(size dimOut)(sliceDim);
+ const $ty:ix gridSize = get_global_size(0);
+
+ for ($ty:ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
+ $ty:slice_dim dst = $id:(fromIndex dimOut)(sliceDim, ix);
+ $ty:slice src = sliceIndex(dst);
+
+ $ty:ix j = $id:(toIndex dimSl)(slice, src);
+ set("ix", getA("j", $args:d_inA), $args:d_out) ;
+ }
+ }
+ |]
@@ -46,9 +46,10 @@ mkSliceIndex :: Exp -> Definition
mkSliceIndex =
mkDeviceFun "sliceIndex" (typename "SliceDim") $ params [(typename "Slice","sl"), (typename "CoSlice","co")]
-mkSliceReplicate :: Exp -> Definition
-mkSliceReplicate =
- mkDeviceFun "sliceIndex" (typename "Slice") $ [param (typename "SliceDim") "dim"]
+mkSliceReplicate :: Exp -> CGM ()
+mkSliceReplicate exp =
+ addDefinition $
+ (mkDeviceFun "sliceIndex" (typename "Slice") $ [param (typename "SliceDim") "dim"]) exp
-- Helper functions

0 comments on commit 2723e86

Please sign in to comment.