Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

A simplification, using a trick in the C-quasiquoter (the C++ typenam…

…e keyword)
  • Loading branch information...
commit ec3cb5cbaaba4b64bbb8d8cd5e3f7866cca6cfbd 1 parent a25e7b9
@dybber dybber authored
View
22 Data/Array/Accelerate/OpenCL/CodeGen.hs
@@ -37,7 +37,7 @@ import Data.Array.Accelerate.Pretty ()
import Data.Array.Accelerate.Analysis.Type
import Data.Array.Accelerate.Analysis.Shape
--import Data.Array.Accelerate.Analysis.Stencil
---import Data.Array.Accelerate.Array.Representation
+import Data.Array.Accelerate.Array.Representation
import qualified Data.Array.Accelerate.Array.Sugar as Sugar
import qualified Foreign.Storable as F
@@ -114,16 +114,16 @@ codeGenAcc acc vars =
ZipWith f a b -> mkZipWith (codeGenAccTypeDim acc) (codeGenAccTypeDim a) (codeGenAccTypeDim b) (codeGenFun f)
Permute f _ g a -> mkPermute (codeGenAccType a) (accDim acc) (accDim a) (codeGenFun f) (codeGenFun g)
Backpermute _ f a -> mkBackpermute (codeGenAccType a) (accDim acc) (accDim a) (codeGenFun f)
- -- Replicate sl _ a ->
- -- let dimSl = accDim a
- -- dimOut = accDim acc
- -- --
- -- extend :: SliceIndex slix sl co dim -> Int -> [CExpr]
- -- extend (SliceNil) _ = []
- -- extend (SliceAll sliceIdx) n = mkPrj dimOut "dim" n : extend sliceIdx (n+1)
- -- extend (SliceFixed sliceIdx) n = extend sliceIdx (n+1)
- -- in
- -- mkReplicate (codeGenAccType a) dimSl dimOut . reverse $ extend sl 0
+ Replicate sl _ a ->
+ let dimSl = accDim a
+ dimOut = accDim acc
+ --
+ extend :: SliceIndex slix sl co dim -> Int -> [C.Exp]
+ extend (SliceNil) _ = []
+ extend (SliceAll sliceIdx) n = mkPrj dimOut "dim" n : extend sliceIdx (n+1)
+ extend (SliceFixed sliceIdx) n = extend sliceIdx (n+1)
+ in
+ mkReplicate (codeGenAccType a) dimSl dimOut . seqexps . reverse $ extend sl 0
-- -- Index sl a slix ->
-- -- let dimCo = length (codeGenExpType slix)
View
106 Data/Array/Accelerate/OpenCL/CodeGen/Skeleton.hs
@@ -46,21 +46,21 @@ import Data.Array.Accelerate.OpenCL.CodeGen.Monad
mkGenerate :: ([C.Type],Int) -> C.Exp -> CUTranslSkel
mkGenerate (tyOut, dimOut) apply = runCGM $ do
d_out <- mkOutputTuple tyOut
- shape_out <- mkShape "DimOut" dimOut
- _ <- mkShape "TyInA" dimOut
+ mkShape "DimOut" dimOut
+ mkShape "TyInA" dimOut
mkApply 1 apply
ps <- getParams
addDefinitions
[cunit|
- __kernel void generate (const $ty:shape_out shOut,
+ __kernel void generate (const typename DimOut shOut,
$params:ps) {
- const $ty:ix n = $id:(size dimOut)(shOut);
- const $ty:ix gridSize = get_global_size(0);
+ const typename Ix n = $id:(size dimOut)(shOut);
+ const typename Ix gridSize = get_global_size(0);
- for ($ty:ix ix = get_global_id(0); ix < n; ix += gridSize) {
- $ty:outType val = apply($id:(fromIndex dimOut)(shOut, ix));
+ for (typename Ix ix = get_global_id(0); ix < n; ix += gridSize) {
+ typename TyOut val = apply($id:(fromIndex dimOut)(shOut, ix));
set(ix, val, $args:d_out);
}
}
@@ -158,25 +158,25 @@ mkZipWith (tyOut,dimOut) (tyInB, dimInB) (tyInA, dimInA) apply =
d_inB <- mkInputTuple "B" tyInB
mkApply 2 apply
- shape_out <- mkShape "DimOut" dimOut
- shape_inB <- mkShape "DimInB" dimInB
- shape_inA <- mkShape "DimInA" dimInA
+ mkShape "DimOut" dimOut
+ mkShape "DimInB" dimInB
+ mkShape "DimInA" dimInA
ps <- getParams
addDefinitions
[cunit|
- __kernel void zipWith (const $ty:shape_out shOut,
- const $ty:shape_inB shInB,
- const $ty:shape_inA shInA,
+ __kernel void zipWith (const typename DimOut shOut,
+ const typename DimInB shInB,
+ const typename DimInA shInA,
$params:ps) {
- const $ty:ix shapeSize = $id:(size dimOut)(shOut);
- const $ty:ix gridSize = get_global_size(0);
+ const typename Ix shapeSize = $id:(size dimOut)(shOut);
+ const typename Ix gridSize = get_global_size(0);
- for ($ty:ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
- $ty:ix iA = $id:(toIndex dimInB)(shInB, $id:(fromIndex dimInB)(shOut, ix));
- $ty:ix iB = $id:(toIndex dimInA)(shInA, $id:(fromIndex dimInA)(shOut, ix));
+ for (typename Ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
+ typename Ix iA = $id:(toIndex dimInB)(shInB, $id:(fromIndex dimInB)(shOut, ix));
+ typename Ix iB = $id:(toIndex dimInA)(shInA, $id:(fromIndex dimInA)(shOut, ix));
- $ty:outType val = apply(getB(iB, $args:d_inB), getA(iA, $args:d_inA)) ;
+ typename TyOut val = apply(getB(iB, $args:d_inB), getA(iA, $args:d_inA)) ;
set(ix, val, $args:d_out) ;
}
}
@@ -281,8 +281,8 @@ mkZipWith (tyOut,dimOut) (tyInB, dimInB) (tyInA, dimInA) apply =
mkPermute :: [C.Type] -> Int -> Int -> C.Exp -> C.Exp -> CUTranslSkel
mkPermute ty dimOut dimInA combinefn indexfn = runCGM $ do
(d_out, d_inA : _) <- mkTupleTypeAsc 2 ty
- shape_out <- mkShape "DimOut" dimOut
- shape_inA <- mkShape "DimInA" dimInA
+ mkShape "DimOut" dimOut
+ mkShape "DimInA" dimInA
mkApply 2 combinefn
mkProject Forward indexfn
@@ -290,21 +290,21 @@ mkPermute ty dimOut dimInA combinefn indexfn = runCGM $ do
ps <- getParams
addDefinitions
[cunit|
- __kernel void permute (const $ty:shape_out shOut,
- const $ty:shape_inA shInA,
+ __kernel void permute (const typename DimOut shOut,
+ const typename DimInA shInA,
$params:ps) {
- const $ty:ix shapeSize = $id:(size dimInA)(shInA);
- const $ty:ix gridSize = get_global_size(0);
+ const typename Ix shapeSize = $id:(size dimInA)(shInA);
+ const typename Ix gridSize = get_global_size(0);
- for ($ty:ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
- $ty:shape_inA src = $id:(fromIndex dimInA)(shIn0, ix);
- $ty:shape_out dst = project(src);
+ for (typename Ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
+ typename DimInA src = $id:(fromIndex dimInA)(shIn0, ix);
+ typename DimOut dst = project(src);
if (!ignore(dst)) {
- $ty:ix j = $id:(toIndex dimOut)(shOut, dst);
+ typename Ix j = $id:(toIndex dimOut)(shOut, dst);
- $ty:outType val = apply(getA(j, $args:d_out),
- getA(ix, $args:d_inA)) ;
+ typename TyOut val = apply(getA(j, $args:d_out),
+ getA(ix, $args:d_inA)) ;
set(j, val, $args:d_out) ;
}
}
@@ -315,25 +315,25 @@ mkPermute ty dimOut dimInA combinefn indexfn = runCGM $ do
mkBackpermute :: [C.Type] -> Int -> Int -> C.Exp -> CUTranslSkel
mkBackpermute ty dimOut dimInA indexFn = runCGM $ do
(d_out, d_inA : _) <- mkTupleTypeAsc 1 ty
- shape_out <- mkShape "DimOut" dimOut
- shape_inA <- mkShape "DimInA" dimInA
+ mkShape "DimOut" dimOut
+ mkShape "DimInA" dimInA
mkProject Backward indexFn
ps <- getParams
addDefinitions
[cunit|
- __kernel void backpermute (const $ty:shape_out shOut,
- const $ty:shape_inA shInA,
- $params:ps) {
- const $ty:ix shapeSize = $id:(size dimInA)(shInA);
- const $ty:ix gridSize = get_global_size(0);
+ __kernel void backpermute (const typename DimOut shOut,
+ const typename DimInA shInA,
+ $params:ps) {
+ const typename Ix shapeSize = $id:(size dimInA)(shInA);
+ const typename Ix gridSize = get_global_size(0);
- for ($ty:ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
- $ty:shape_out src = $id:(fromIndex dimOut)(shOut, ix);
- $ty:shape_inA src = project(dst);
+ for (typename Ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
+ typename DimOut dst = $id:(fromIndex dimOut)(shOut, ix);
+ typename DimInA src = project(dst);
- $ty:ix j = $id:(toIndex dimInA)(shInA, dst);
+ typename Ix j = $id:(toIndex dimInA)(shInA, dst);
set(ix, getA(j, $args:d_inA), $args:d_out) ;
}
}
@@ -358,25 +358,25 @@ mkBackpermute ty dimOut dimInA indexFn = runCGM $ do
mkReplicate :: [C.Type] -> Int -> Int -> C.Exp -> CUTranslSkel
mkReplicate ty dimSl dimOut slix = runCGM $ do
(d_out, d_inA : _) <- mkTupleTypeAsc 1 ty
- slice <- mkShape "Slice" dimSl
- slice_dim <- mkShape "SliceDim" dimOut
+ mkShape "Slice" dimSl
+ mkShape "SliceDim" dimOut
mkSliceReplicate slix
ps <- getParams
addDefinitions
[cunit|
- __kernel void replicate (const $ty:slice shOut,
- const $ty:slice_dim shInA,
- $params:ps) {
- const $ty:ix shapeSize = $id:(size dimOut)(sliceDim);
- const $ty:ix gridSize = get_global_size(0);
+ __kernel void replicate (const typename Slice shOut,
+ const typename SliceDim shInA,
+ $params:ps) {
+ const typename Ix shapeSize = $id:(size dimOut)(sliceDim);
+ const typename Ix gridSize = get_global_size(0);
- for ($ty:ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
- $ty:slice_dim dst = $id:(fromIndex dimOut)(sliceDim, ix);
- $ty:slice src = sliceIndex(dst);
+ for (typename Ix ix = get_global_id(0); ix < shapeSize; ix += gridSize) {
+ typename SliceDim dst = $id:(fromIndex dimOut)(sliceDim, ix);
+ typename Slice src = sliceIndex(dst);
- $ty:ix j = $id:(toIndex dimSl)(slice, src);
+ typename Ix j = $id:(toIndex dimSl)(slice, src);
set(ix, getA(j, $args:d_inA), $args:d_out) ;
}
}
View
12 Data/Array/Accelerate/OpenCL/CodeGen/Util.hs
@@ -108,14 +108,10 @@ mkTypedef volatile tyname typ | volatile = let typ' = mkVolatile typ
in [cedecl|typedef $ty:typ' $id:tyname;|]
| otherwise = [cedecl|typedef $ty:typ $id:tyname;|]
-mkShape :: String -> Int -> CGM Type
-mkShape name dim = do
- addDefinition typedef
- return $ typename name
- where
- typedef | dim == 0 = [cedecl| typedef void* $id:name; |]
- | dim == 1 = [cedecl| typedef $ty:ix $id:name; |]
- | otherwise = mkStruct name False (replicate dim ix)
+mkShape :: String -> Int -> CGM ()
+mkShape name 0 = addDefinition [cedecl| typedef void* $id:name; |]
+mkShape name 1 = addDefinition [cedecl| typedef $ty:ix $id:name; |]
+mkShape name dim = addDefinition $ mkStruct name False (replicate dim ix)
toIndex :: Int -> String
toIndex dim = "toIndexDIM" ++ show dim
View
2  Data/Array/Accelerate/OpenCL/Compile.hs
@@ -595,7 +595,7 @@ compile table key acc fvar = do
-- Compile in another thread
_ <- liftIO . forkIO $ do
let p = (show $ codeGenAcc acc fvar)
- putStrLn p
+-- putStrLn p
prog <- OpenCL.createProgram ctx p
OpenCL.buildProgram prog (map fst devices) =<< compileFlags
putMVar mvar prog
Please sign in to comment.
Something went wrong with that request. Please try again.