Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
  • 5 commits
  • 4 files changed
  • 0 commit comments
  • 1 contributor
View
145 Data/Array/Accelerate/SimpleAST.hs
@@ -19,24 +19,31 @@ module Data.Array.Accelerate.SimpleAST
payloadLength, applyToPayload, applyToPayload2, applyToPayload3,
-- * Helper routines and predicates:
- var, isIntType, isFloatType
+ var, isIntType, isFloatType,
+ Data.Array.Accelerate.SimpleAST.replicate
)
where
-import Data.Int
-import Data.Word
-import Data.Array.Unboxed as U
-import Foreign.C.Types
-import Pretty (text) -- ghc api
-import Text.PrettyPrint.GenericPretty (Out(doc,docPrec), Generic)
+import Debug.Trace
+import Data.Int
+import Data.Word
+import Data.Array.Unboxed as U
+import qualified Data.Array.Unsafe as Un
+import qualified Data.Array.MArray as MA
+import qualified Data.Array.IO as IA
+import Foreign.C.Types
+import Pretty (text) -- ghc api
+import Text.PrettyPrint.GenericPretty (Out(doc,docPrec), Generic)
+import System.IO.Unsafe (unsafePerformIO)
--------------------------------------------------------------------------------
-- Prelude: Pick a simple representation of variables (interned symbols)
--------------------------------------------------------------------------------
-- Several modules offer this, with varying problems:
----------------------------
-#define USE_STRINGTABLE
+#define USE_SYMBOL
#ifdef USE_STRINGTABLE
-- 'stringtable-atom' package:
+-- I'm getting some segfaults here [2012.05.19];
import StringTable.Atom
var = toAtom
type Var = Atom
@@ -158,7 +165,7 @@ data Const = I Int | I8 Int8 | I16 Int16 | I32 Int32 | I64 Int64
| F Float | D Double | C Char | B Bool
| Tup [Const]
-- Special constants:
- | MinBound | MaxBound | Pi
+-- | MinBound | MaxBound | Pi
-- C types, rather annoying:
| CF CFloat | CD CDouble
| CS CShort | CI CInt | CL CLong | CLL CLLong
@@ -253,12 +260,14 @@ isFloatType ty =
-- order that one would write `(Z :. 3 :. 4 :. All)` in the source code;
-- i.e. that particular example would translate to `[All, Fixed, Fixed]`.
--
+-- The result is that the "fastest varying" dimension is on the left
+-- in this representation.
type SliceType = [SliceComponent]
data SliceComponent = Fixed | All
deriving (Eq,Show,Read,Generic)
-- TEMP / OLD:
--- The read left-to-right, in the same
+-- They read left-to-right, in the same
-- order that one would write `(Z :. 3 :. 4 :. All)` in the source code.
-- That particular example would translate to `[Fixed, Fixed, All]`.
@@ -297,10 +306,13 @@ data ArrayPayload =
| ArrayPayloadDouble (RawData Double)
| ArrayPayloadChar (RawData Char)
| ArrayPayloadBool (RawData Word8) -- Word8's represent bools.
- | ArrayPayloadUnit -- Dummy placeholder value.
+-- | ArrayPayloadUnit -- Dummy placeholder value.
+--
+-- TODO -- Add C-types. But as of this date [2012.05.21], Accelerate
+-- support for these types is incomplete, so we omit them here as well.
--
- -- TODO: UArray doesn't offer cast like IOArray. It would be nice
- -- to make all arrays canonicalized to a data buffer of Word8's:
+-- TODO: UArray doesn't offer cast like IOArray. It would be nice
+-- to make all arrays canonicalized to a data buffer of Word8's:
deriving (Show, Read, Eq)
-- | This is our Haskell representation of raw, contiguous data.
@@ -383,7 +395,7 @@ test = read "array (1,5) [(1,200),(2,201),(3,202),(4,203),(5,204)]" :: U.UArray
payloadLength :: ArrayPayload -> Int
payloadLength payl =
case payl of
- ArrayPayloadUnit -> 0
+-- ArrayPayloadUnit -> 0
ArrayPayloadInt arr -> arrLen arr
ArrayPayloadInt8 arr -> arrLen arr
ArrayPayloadInt16 arr -> arrLen arr
@@ -426,7 +438,7 @@ applyToPayload fn payl = applyToPayload2 (\ a _ -> fn a) payl
applyToPayload2 :: (forall a . UArray Int a -> (Int -> Const) -> UArray Int a) -> ArrayPayload -> ArrayPayload
applyToPayload2 fn payl =
case payl of
- ArrayPayloadUnit -> ArrayPayloadUnit
+-- ArrayPayloadUnit -> ArrayPayloadUnit
ArrayPayloadInt arr -> ArrayPayloadInt (fn arr (\i -> I (arr U.! i)))
ArrayPayloadInt8 arr -> ArrayPayloadInt8 (fn arr (\i -> I8 (arr U.! i)))
ArrayPayloadInt16 arr -> ArrayPayloadInt16 (fn arr (\i -> I16 (arr U.! i)))
@@ -450,25 +462,24 @@ applyToPayload2 fn payl =
applyToPayload3 :: (Int -> (Int -> Const) -> [Const]) -> ArrayPayload -> ArrayPayload
-- TODO!! The same-type-as-input restriction could be relaxed.
applyToPayload3 fn payl =
- let len = payloadLength payl in
case payl of
- ArrayPayloadUnit -> ArrayPayloadUnit
- ArrayPayloadInt arr -> ArrayPayloadInt (U.listArray (0,len) (map unI $ fn len (\i -> I (arr U.! i))))
- ArrayPayloadInt8 arr -> ArrayPayloadInt8 (U.listArray (0,len) (map unI8 $ fn len (\i -> I8 (arr U.! i))))
- ArrayPayloadInt16 arr -> ArrayPayloadInt16 (U.listArray (0,len) (map unI16$ fn len (\i -> I16 (arr U.! i))))
- ArrayPayloadInt32 arr -> ArrayPayloadInt32 (U.listArray (0,len) (map unI32$ fn len (\i -> I32 (arr U.! i))))
- ArrayPayloadInt64 arr -> ArrayPayloadInt64 (U.listArray (0,len) (map unI64$ fn len (\i -> I64 (arr U.! i))))
- ArrayPayloadWord arr -> ArrayPayloadWord (U.listArray (0,len) (map unW $ fn len (\i -> W (arr U.! i))))
- ArrayPayloadWord8 arr -> ArrayPayloadWord8 (U.listArray (0,len) (map unW8 $ fn len (\i -> W8 (arr U.! i))))
- ArrayPayloadWord16 arr -> ArrayPayloadWord16 (U.listArray (0,len) (map unW16$ fn len (\i -> W16 (arr U.! i))))
- ArrayPayloadWord32 arr -> ArrayPayloadWord32 (U.listArray (0,len) (map unW32$ fn len (\i -> W32 (arr U.! i))))
- ArrayPayloadWord64 arr -> ArrayPayloadWord64 (U.listArray (0,len) (map unW64$ fn len (\i -> W64 (arr U.! i))))
- ArrayPayloadFloat arr -> ArrayPayloadFloat (U.listArray (0,len) (map unF$ fn len (\i -> F (arr U.! i))))
- ArrayPayloadDouble arr -> ArrayPayloadDouble (U.listArray (0,len) (map unD$ fn len (\i -> D (arr U.! i))))
- ArrayPayloadChar arr -> ArrayPayloadChar (U.listArray (0,len) (map unC$ fn len (\i -> C (arr U.! i))))
- ArrayPayloadBool arr -> ArrayPayloadBool (U.listArray (0,len)
- (map fromBool$ fn len (\i -> toBool (arr U.! i))))
+-- ArrayPayloadUnit -> ArrayPayloadUnit
+ ArrayPayloadInt arr -> ArrayPayloadInt (fromL (map unI $ fn len (\i -> I (arr U.! i))))
+ ArrayPayloadInt8 arr -> ArrayPayloadInt8 (fromL (map unI8 $ fn len (\i -> I8 (arr U.! i))))
+ ArrayPayloadInt16 arr -> ArrayPayloadInt16 (fromL (map unI16$ fn len (\i -> I16 (arr U.! i))))
+ ArrayPayloadInt32 arr -> ArrayPayloadInt32 (fromL (map unI32$ fn len (\i -> I32 (arr U.! i))))
+ ArrayPayloadInt64 arr -> ArrayPayloadInt64 (fromL (map unI64$ fn len (\i -> I64 (arr U.! i))))
+ ArrayPayloadWord arr -> ArrayPayloadWord (fromL (map unW $ fn len (\i -> W (arr U.! i))))
+ ArrayPayloadWord8 arr -> ArrayPayloadWord8 (fromL (map unW8 $ fn len (\i -> W8 (arr U.! i))))
+ ArrayPayloadWord16 arr -> ArrayPayloadWord16 (fromL (map unW16$ fn len (\i -> W16 (arr U.! i))))
+ ArrayPayloadWord32 arr -> ArrayPayloadWord32 (fromL (map unW32$ fn len (\i -> W32 (arr U.! i))))
+ ArrayPayloadWord64 arr -> ArrayPayloadWord64 (fromL (map unW64$ fn len (\i -> W64 (arr U.! i))))
+ ArrayPayloadFloat arr -> ArrayPayloadFloat (fromL (map unF $ fn len (\i -> F (arr U.! i))))
+ ArrayPayloadDouble arr -> ArrayPayloadDouble (fromL (map unD $ fn len (\i -> D (arr U.! i))))
+ ArrayPayloadChar arr -> ArrayPayloadChar (fromL (map unC $ fn len (\i -> C (arr U.! i))))
+ ArrayPayloadBool arr -> ArrayPayloadBool (fromL (map fromBool$ fn len (\i -> toBool (arr U.! i))))
where
+ len = payloadLength payl
unI (I x) = x
unI8 (I8 x) = x
unI16 (I16 x) = x
@@ -483,7 +494,69 @@ applyToPayload3 fn payl =
unD (D x) = x
unC (C x) = x
unB (B x) = x
- toBool 0 = B False
- toBool _ = B True
- fromBool (B False) = 0
- fromBool (B True) = 1
+
+fromL l = U.listArray (0,length l - 1) l
+toBool 0 = B False
+toBool _ = B True
+fromBool (B False) = 0
+fromBool (B True) = 1
+
+
+-- | Create an array of with the given dimensions and many copies of
+-- the same element. This deals with constructing the appropriate
+-- type of payload to match the type of constant (which is otherwise
+-- a large case statement).
+replicate :: [Int] -> Const -> AccArray
+replicate dims const = AccArray dims (payload const)
+ where
+ len = foldl (*) 1 dims
+ payload const =
+ case const of
+ I x -> [ArrayPayloadInt (fromL$ Prelude.replicate len x)]
+ I8 x -> [ArrayPayloadInt8 (fromL$ Prelude.replicate len x)]
+ I16 x -> [ArrayPayloadInt16 (fromL$ Prelude.replicate len x)]
+ I32 x -> [ArrayPayloadInt32 (fromL$ Prelude.replicate len x)]
+ I64 x -> [ArrayPayloadInt64 (fromL$ Prelude.replicate len x)]
+ W x -> [ArrayPayloadWord (fromL$ Prelude.replicate len x)]
+ W8 x -> [ArrayPayloadWord8 (fromL$ Prelude.replicate len x)]
+ W16 x -> [ArrayPayloadWord16 (fromL$ Prelude.replicate len x)]
+ W32 x -> [ArrayPayloadWord32 (fromL$ Prelude.replicate len x)]
+ W64 x -> [ArrayPayloadWord64 (fromL$ Prelude.replicate len x)]
+ F x -> [ArrayPayloadFloat (fromL$ Prelude.replicate len x)]
+ D x -> [ArrayPayloadDouble (fromL$ Prelude.replicate len x)]
+ C x -> [ArrayPayloadChar (fromL$ Prelude.replicate len x)]
+ B x -> [ArrayPayloadBool (fromL$ Prelude.replicate len (fromBool const))]
+ Tup ls -> concatMap payload ls
+
+-- TODO -- add all C array types to the ArrayPayload type:
+-- | CF CFloat | CD CDouble
+-- | CS CShort | CI CInt | CL CLong | CLL CLLong
+-- | CUS CUShort | CUI CUInt | CUL CULong | CULL CULLong
+-- | CC CChar | CSC CSChar | CUC CUChar
+
+
+----------------------------------------------------------------------------------------------------
+
+-- Note: an alternate approach to the large sum of payload types would
+-- be to cast them all to a UArray of bytes. There is not direct
+-- support for this in the UArray module, but we could accomplish it
+-- with something like the following:
+castUArray :: forall ix a b . (Ix ix, IArray UArray a, IArray UArray b,
+ IA.MArray IA.IOUArray a IO, IA.MArray IA.IOUArray b IO)
+ => UArray ix a -> UArray ix b
+castUArray uarr = unsafePerformIO $
+ do thawed :: IA.IOUArray ix a <- MA.unsafeThaw uarr
+ cast :: IA.IOUArray ix b <- Un.castIOUArray thawed
+ froze :: UArray ix b <- MA.unsafeFreeze cast
+ return froze
+
+-- Like Data.Vector.generate, but for `UArray`s. Unfortunately, this
+-- requires extra class constraints for `IOUArray` as well.
+uarrGenerate :: (IArray UArray a, IA.MArray IA.IOUArray a IO)
+ => Int -> (Int -> a) -> UArray Int a
+uarrGenerate len fn = unsafePerformIO $
+ do marr :: IA.IOUArray Int a <- MA.newArray_ (0,len)
+ let loop (-1) = MA.unsafeFreeze marr
+ loop i = do MA.writeArray marr i (fn i)
+ loop (i-1)
+ loop (len-1)
View
212 Data/Array/Accelerate/SimpleConverter.hs
@@ -35,7 +35,7 @@ import qualified Data.Array.Accelerate.SimpleAST as S
import qualified Data.List as L
--- import Debug.Trace(trace)
+import Debug.Trace(trace)
-- tracePrint s x = trace (s ++ show x) x
--------------------------------------------------------------------------------
@@ -124,70 +124,19 @@ convertAcc (OpenAcc cacc) = convertPreOpenAcc cacc
------------------------------------------------------------
-- Array creation:
-
-- These should include types.
-
+
Generate sh f -> S.Generate (getAccTypePre eacc)
<$> convertExp sh
<*> convertFun f
-- This is real live runtime array data:
- -- orig@(Use (arrrepr :: Sug.ArrRepr a)) | (_ :: PreOpenAcc OpenAcc aenv a) <- orig ->
- Use (arrrepr :: Sug.ArrRepr a) ->
- return$ S.Use ty (S.AccArray shp payloads)
- where
- shp = case L.group shps of
- [] -> []
- [(hd : _gr1)] -> hd
- ls -> error$"Use: corrupt Accelerate array -- arrays components did not have identical shape:"
- ++ show (concat ls)
- (shps, payloads) = cvt2 repOf actualArr
-
- ty = convertArrayType repOf
- repOf = Sug.arrays actualArr :: Sug.ArraysR (Sug.ArrRepr a)
- actualArr = Sug.toArr arrrepr :: a
-
- cvt :: Sug.ArraysR a' -> a' -> ([[Int]],[S.ArrayPayload])
- cvt Sug.ArraysRunit () = ([],[])
- cvt (Sug.ArraysRpair r1 r2) (a1, a2) = cvt r1 a1 `combine` cvt r2 a2
- cvt Sug.ArraysRarray arr = convertArrayValue arr
-
- -- -- Takes an Array representation and its reified type:
- cvt2 :: (Sug.Arrays a') => Sug.ArraysR (Sug.ArrRepr a') -> a' -> ([[Int]],[S.ArrayPayload])
- cvt2 tyReified arr =
- case (tyReified, Sug.fromArr arr) of
- (Sug.ArraysRunit, ()) -> ([],[])
- (Sug.ArraysRpair r1 r2, (a1, a2)) -> cvt r1 a1 `combine` cvt r2 a2
- (Sug.ArraysRarray, arr2) -> convertArrayValue arr2
-
- combine (a,b) (x,y) = (a++x, b++y)
-
- convertArrayValue :: forall dim e . (Sug.Elt e) => Sug.Array dim e -> ([[Int]],[S.ArrayPayload])
- convertArrayValue (Sug.Array shpVal adata) =
- ([Sug.shapeToList (Sug.toElt shpVal :: dim)],
- useR arrayElt adata)
- where
- -- This [mandatory] type signature forces the array data to be the
- -- same type as the ArrayElt Representation (elt ~ elt):
- useR :: ArrayEltR elt -> ArrayData elt -> [S.ArrayPayload]
- useR (ArrayEltRpair aeR1 aeR2) ad =
- (useR aeR1 (fstArrayData ad)) ++
- (useR aeR2 (sndArrayData ad))
- useR ArrayEltRunit _ = [S.ArrayPayloadUnit]
- useR ArrayEltRint (AD_Int x) = [S.ArrayPayloadInt x]
- useR ArrayEltRint8 (AD_Int8 x) = [S.ArrayPayloadInt8 x]
- useR ArrayEltRint16 (AD_Int16 x) = [S.ArrayPayloadInt16 x]
- useR ArrayEltRint32 (AD_Int32 x) = [S.ArrayPayloadInt32 x]
- useR ArrayEltRint64 (AD_Int64 x) = [S.ArrayPayloadInt64 x]
- useR ArrayEltRword (AD_Word x) = [S.ArrayPayloadWord x]
- useR ArrayEltRword8 (AD_Word8 x) = [S.ArrayPayloadWord8 x]
- useR ArrayEltRword16 (AD_Word16 x) = [S.ArrayPayloadWord16 x]
- useR ArrayEltRword32 (AD_Word32 x) = [S.ArrayPayloadWord32 x]
- useR ArrayEltRword64 (AD_Word64 x) = [S.ArrayPayloadWord64 x]
- useR ArrayEltRfloat (AD_Float x) = [S.ArrayPayloadFloat x]
- useR ArrayEltRdouble (AD_Double x) = [S.ArrayPayloadDouble x]
- useR ArrayEltRbool (AD_Bool x) = [S.ArrayPayloadBool x]
- useR ArrayEltRchar (AD_Char x) = [S.ArrayPayloadChar x]
+ Use (arrrepr :: Sug.ArrRepr a) ->
+ -- This is rather odd, but we need to have a dummy return
+ -- value to avoid errors about ArrRepr type functions not
+ -- being injective.
+ let (ty,arr,_::Phantom a) = unpackArray arrrepr in
+ return$ S.Use ty arr
-- End Array creation prims.
------------------------------------------------------------
@@ -349,17 +298,69 @@ convertExp e =
Cond c t ex -> S.ECond <$> convertExp c
<*> convertExp t
<*> convertExp ex
- PrimConst c -> return$ S.EConst $
- case c of
- PrimMinBound _ -> S.MinBound
- PrimMaxBound _ -> S.MaxBound
- PrimPi _ -> S.Pi
-
+
IndexScalar acc eix -> S.EIndexScalar <$> convertAcc acc
<*> convertExp eix
Shape acc -> S.EShape <$> convertAcc acc
ShapeSize acc -> S.EShapeSize <$> convertExp acc
+ -- We are committed to specific binary representations of numeric
+ -- types anyway, so we simply encode special constants here,
+ -- rather than preserving their specialness:
+ PrimConst c -> return$ S.EConst $
+ case (c, getExpType e) of
+ (PrimPi _, S.TFloat) -> S.F pi
+ (PrimPi _, S.TDouble) -> S.D pi
+ (PrimPi _, S.TCFloat) -> S.CF pi
+ (PrimPi _, S.TCDouble) -> S.CD pi
+ (PrimMinBound _, S.TInt) -> S.I minBound
+ (PrimMinBound _, S.TInt8) -> S.I8 minBound
+ (PrimMinBound _, S.TInt16) -> S.I16 minBound
+ (PrimMinBound _, S.TInt32) -> S.I32 minBound
+ (PrimMinBound _, S.TInt64) -> S.I64 minBound
+ (PrimMinBound _, S.TWord) -> S.W minBound
+ (PrimMinBound _, S.TWord8) -> S.W8 minBound
+ (PrimMinBound _, S.TWord16) -> S.W16 minBound
+ (PrimMinBound _, S.TWord32) -> S.W32 minBound
+ (PrimMinBound _, S.TWord64) -> S.W64 minBound
+ (PrimMinBound _, S.TCShort) -> S.CS minBound
+ (PrimMinBound _, S.TCInt ) -> S.CI minBound
+ (PrimMinBound _, S.TCLong ) -> S.CL minBound
+ (PrimMinBound _, S.TCLLong) -> S.CLL minBound
+ (PrimMinBound _, S.TCUShort) -> S.CUS minBound
+ (PrimMinBound _, S.TCUInt ) -> S.CUI minBound
+ (PrimMinBound _, S.TCULong ) -> S.CUL minBound
+ (PrimMinBound _, S.TCULLong) -> S.CULL minBound
+ (PrimMinBound _, S.TChar ) -> S.C minBound
+ (PrimMinBound _, S.TCChar ) -> S.CC minBound
+ (PrimMinBound _, S.TCSChar) -> S.CSC minBound
+ (PrimMinBound _, S.TCUChar) -> S.CUC minBound
+ (PrimMaxBound _, S.TInt) -> S.I maxBound
+ (PrimMaxBound _, S.TInt8) -> S.I8 maxBound
+ (PrimMaxBound _, S.TInt16) -> S.I16 maxBound
+ (PrimMaxBound _, S.TInt32) -> S.I32 maxBound
+ (PrimMaxBound _, S.TInt64) -> S.I64 maxBound
+ (PrimMaxBound _, S.TWord) -> S.W maxBound
+ (PrimMaxBound _, S.TWord8) -> S.W8 maxBound
+ (PrimMaxBound _, S.TWord16) -> S.W16 maxBound
+ (PrimMaxBound _, S.TWord32) -> S.W32 maxBound
+ (PrimMaxBound _, S.TWord64) -> S.W64 maxBound
+ (PrimMaxBound _, S.TCShort) -> S.CS maxBound
+ (PrimMaxBound _, S.TCInt ) -> S.CI maxBound
+ (PrimMaxBound _, S.TCLong ) -> S.CL maxBound
+ (PrimMaxBound _, S.TCLLong) -> S.CLL maxBound
+ (PrimMaxBound _, S.TCUShort) -> S.CUS maxBound
+ (PrimMaxBound _, S.TCUInt ) -> S.CUI maxBound
+ (PrimMaxBound _, S.TCULong ) -> S.CUL maxBound
+ (PrimMaxBound _, S.TCULLong) -> S.CULL maxBound
+ (PrimMaxBound _, S.TChar ) -> S.C maxBound
+ (PrimMaxBound _, S.TCChar ) -> S.CC maxBound
+ (PrimMaxBound _, S.TCSChar) -> S.CSC maxBound
+ (PrimMaxBound _, S.TCUChar) -> S.CUC maxBound
+ (PrimMinBound _,ty) -> error$"Internal error: no minBound for type"++show ty
+ (PrimMaxBound _,ty) -> error$"Internal error: no maxBound for type"++show ty
+ (PrimPi _,ty) -> error$"Internal error: no pi constant for type"++show ty
+
-- Convert a tuple expression to our simpler Tuple representation (containing a list):
-- convertTuple :: Tuple (PreOpenExp acc env aenv) t' -> S.AExp
@@ -554,3 +555,84 @@ convertFun = loop []
v <- envLookup 0
loop ((v,sty) : acc) f2
return x
+
+--------------------------------------------------------------------------------
+-- Convert Accelerate Array Data
+--------------------------------------------------------------------------------
+
+-- | Used only for communicating type information.
+data Phantom a = Phantom
+
+-- | This converts Accelerate Array data. It has an odd return type
+-- to avoid type-family related type errors.
+unpackArray :: forall a . (Sug.Arrays a) => Sug.ArrRepr a -> (S.Type, S.AccArray, Phantom a)
+unpackArray arrrepr = (ty, S.AccArray shp payloads,
+ Phantom :: Phantom a)
+ where
+ shp = case L.group shps of
+ [] -> []
+ [(hd : _gr1)] -> hd
+ ls -> error$"Use: corrupt Accelerate array -- arrays components did not have identical shape:"
+ ++ show (concat ls)
+ (shps, payloads) = cvt2 repOf actualArr
+ ty = convertArrayType repOf
+ repOf = Sug.arrays actualArr :: Sug.ArraysR (Sug.ArrRepr a)
+ actualArr = Sug.toArr arrrepr :: a
+
+ -- cvt and cvt2 return a list of shapes together with a list of raw data payloads.
+ --
+ -- I'm afraid I don't understand the two-level pairing that
+ -- is going on here (ArraysRpair + ArraysEltRpair)
+ cvt :: Sug.ArraysR a' -> a' -> ([[Int]],[S.ArrayPayload])
+ cvt Sug.ArraysRunit () = ([],[])
+ cvt (Sug.ArraysRpair r1 r2) (a1, a2) = cvt r1 a1 `combine` cvt r2 a2
+ cvt Sug.ArraysRarray arr = cvt3 arr
+
+ -- Takes an Array representation and its reified type:
+ cvt2 :: (Sug.Arrays a') => Sug.ArraysR (Sug.ArrRepr a') -> a' -> ([[Int]],[S.ArrayPayload])
+ cvt2 tyReified arr =
+ case (tyReified, Sug.fromArr arr) of
+ (Sug.ArraysRunit, ()) -> ([],[])
+ (Sug.ArraysRpair r1 r2, (a1, a2)) -> cvt r1 a1 `combine` cvt r2 a2
+ (Sug.ArraysRarray, arr2) -> cvt3 arr2
+
+ combine (a,b) (x,y) = (a++x, b++y)
+
+ cvt3 :: forall dim e . (Sug.Elt e) => Sug.Array dim e -> ([[Int]],[S.ArrayPayload])
+ cvt3 (Sug.Array shpVal adata) =
+ ([Sug.shapeToList (Sug.toElt shpVal :: dim)],
+ useR arrayElt adata)
+ where
+ -- This [mandatory] type signature forces the array data to be the
+ -- same type as the ArrayElt Representation (elt ~ elt):
+ useR :: ArrayEltR elt -> ArrayData elt -> [S.ArrayPayload]
+ useR (ArrayEltRpair aeR1 aeR2) ad =
+ (useR aeR1 (fstArrayData ad)) ++
+ (useR aeR2 (sndArrayData ad))
+ -- useR ArrayEltRunit _ = [S.ArrayPayloadUnit]
+ useR ArrayEltRunit _ = []
+ useR ArrayEltRint (AD_Int x) = [S.ArrayPayloadInt x]
+ useR ArrayEltRint8 (AD_Int8 x) = [S.ArrayPayloadInt8 x]
+ useR ArrayEltRint16 (AD_Int16 x) = [S.ArrayPayloadInt16 x]
+ useR ArrayEltRint32 (AD_Int32 x) = [S.ArrayPayloadInt32 x]
+ useR ArrayEltRint64 (AD_Int64 x) = [S.ArrayPayloadInt64 x]
+ useR ArrayEltRword (AD_Word x) = [S.ArrayPayloadWord x]
+ useR ArrayEltRword8 (AD_Word8 x) = [S.ArrayPayloadWord8 x]
+ useR ArrayEltRword16 (AD_Word16 x) = [S.ArrayPayloadWord16 x]
+ useR ArrayEltRword32 (AD_Word32 x) = [S.ArrayPayloadWord32 x]
+ useR ArrayEltRword64 (AD_Word64 x) = [S.ArrayPayloadWord64 x]
+ useR ArrayEltRfloat (AD_Float x) = [S.ArrayPayloadFloat x]
+ useR ArrayEltRdouble (AD_Double x) = [S.ArrayPayloadDouble x]
+ useR ArrayEltRbool (AD_Bool x) = [S.ArrayPayloadBool x]
+ useR ArrayEltRchar (AD_Char x) = [S.ArrayPayloadChar x]
+
+
+-- | Inverse of previous function, repack the simplified data
+-- representation with the type information necessary to form a proper
+-- Accelerate array.
+packArray :: forall a . (Sug.Arrays a) => (S.Type, S.AccArray, Phantom a) -> Sug.ArrRepr a
+packArray (ty, accarray, _) =
+ undefined
+
+--------------------------------------------------------------------------------
+
View
110 Data/Array/Accelerate/SimpleInterp.hs
@@ -1,4 +1,4 @@
-
+{-# LANGUAGE CPP #-}
-- An example interpreter for the simplified AST.
module Data.Array.Accelerate.SimpleInterp
@@ -9,7 +9,7 @@ module Data.Array.Accelerate.SimpleInterp
import Data.Array.Accelerate.Smart (Acc)
import qualified Data.Array.Accelerate.Array.Sugar as Sug
-import Data.Array.Accelerate.SimpleAST
+import Data.Array.Accelerate.SimpleAST as S
import Data.Array.Accelerate.SimpleConverter (convertToSimpleAST)
import qualified Data.Map as M
@@ -30,16 +30,12 @@ lookup = error"lookup"
data Value = TupVal [Value]
| ArrVal AccArray
| Scalar { unScalar :: Const }
+ deriving Show
--------------------------------------------------------------------------------
-singleton (Scalar s) = ArrVal (error "finish me")
-
--- repackResult
-
run :: Sug.Arrays a => Acc a -> a
run acc = error (show (evalA M.empty (convertToSimpleAST acc)))
-
evalA :: Env -> AExp -> AccArray
evalA env ae = finalArr
where
@@ -51,7 +47,8 @@ evalA env ae = finalArr
Vr v -> env M.! v
Let vr ty lhs bod -> ArrVal$ evalA (M.insert vr (loop lhs) env) bod
- Unit e -> singleton (evalE M.empty e)
+ Unit e -> case evalE M.empty e of
+ Scalar c -> ArrVal$ S.replicate [] c
ArrayTuple aes -> TupVal (map loop aes)
Cond e1 ae2 ae3 -> case evalE env e1 of
@@ -77,33 +74,33 @@ evalA env ae = finalArr
-- Shave off leftmost dim in 'sh' list
-- (the rightmost dim in the user's (Z :. :.) expression):
Fold (Lam [(v1,_),(v2,_)] bodE) ex ae ->
- trace ("FOLDING, shape "++show (innerdim:sh') ++ " arr "++show payloads++"\n") $
+ trace ("FOLDING, shape "++show (innerdim:sh') ++ " lens "++
+ show (alllens, L.group alllens) ++" arr "++show payloads++"\n") $
case payloads of
- [] -> error "Empty payloads!"
+ [] -> error "Empty payloads!"
_ -> ArrVal (AccArray sh' payloads')
where initacc = evalE env ex
AccArray (innerdim:sh') payloads = evalA env ae -- Must be >0 dimensional.
+ payloads' = map (applyToPayload3 buildFolded) payloads
- [len:_] = tracePrint"GROUP"$ L.group $ map payloadLength payloads
+ alllens = map payloadLength payloads
+ len = case L.group alllens of
+ [len:_] -> len
+ x -> error$ "Corrupt Accelerate array. Non-homogenous payload lengths: "++show x
-- Cut the total size down by whatever the length of the inner dimension is:
newlen = len `quot` innerdim
- -- dofold :: UArray Int Float -> A.Array Int Value -- TODO: generalize type
- -- dofold arr = A.listArray (0, newlen) $
- -- [ innerloop (\i -> F (arr U.! i)) (innerdim * i) 0 initacc | i <- [0..newlen] ]
-
- payloads' = map (applyToPayload3 buildFolded) payloads
-
buildFolded :: Int -> (Int -> Const) -> [Const]
- buildFolded _ lookup =
- [ unScalar (innerloop lookup (innerdim * i) 0 initacc)
+ buildFolded _ lookup = tracePrint "\nbuildFOLDED : "$
+ [ unScalar (innerloop lookup (innerdim * i) innerdim initacc)
| i <- [0..newlen] ]
-- The innermost dim is always contiguous in memory.
innerloop :: (Int -> Const) -> Int -> Int -> Value -> Value
innerloop _ _ 0 acc = acc
innerloop lookup offset count acc =
+ trace ("Inner looping "++show(offset,count,acc))$
innerloop lookup (offset+1) (count-1) $
evalE (M.insert v1 acc $
M.insert v2 (Scalar$ lookup offset) env)
@@ -190,25 +187,74 @@ evalPrim p [] =
evalPrim p es =
case p of
- NP Add -> Scalar (foldl1 plus (map unScalar es))
-
-plus :: Const -> Const -> Const
-plus (I a) (I b) = I (a+b)
-
--- Todo: special constants: minBound, maxBound, pi
-
--- data Prim = NP NumPrim
+ NP Add -> Scalar (foldl1 add (map unScalar es))
+ NP Mul -> Scalar (foldl1 mul (map unScalar es))
+ NP Neg -> Scalar (neg $ unScalar $ head es)
+ NP Abs -> Scalar (absv $ unScalar $ head es)
+ NP Sig -> Scalar (sig $ unScalar $ head es)
-- | IP IntPrim
-- | FP FloatPrim
-- | SP ScalarPrim
-- | BP BoolPrim
-- | OP OtherPrim
--- deriving (Read,Show,Eq,Generic)
+
+
+add :: Const -> Const -> Const
+#define ADD(X) add (X a) (X b) = X (a+b);
+ADD(I) ADD(I8) ADD(I16) ADD(I32) ADD(I64)
+ADD(W) ADD(W8) ADD(W16) ADD(W32) ADD(W64)
+ADD(F) ADD(D) ADD(CF) ADD(CD)
+ADD(CS) ADD(CI) ADD(CL) ADD(CLL)
+ADD(CUS) ADD(CUI) ADD(CUL) ADD(CULL)
+ADD(CC) ADD(CUC) ADD(CSC)
+add a b = error $ "add: unsupported combination of values: "++show (a,b)
+
+mul :: Const -> Const -> Const
+#define MUL(X) mul (X a) (X b) = X (a*b);
+MUL(I) MUL(I8) MUL(I16) MUL(I32) MUL(I64)
+MUL(W) MUL(W8) MUL(W16) MUL(W32) MUL(W64)
+MUL(F) MUL(D) MUL(CF) MUL(CD)
+MUL(CS) MUL(CI) MUL(CL) MUL(CLL)
+MUL(CUS) MUL(CUI) MUL(CUL) MUL(CULL)
+MUL(CC) MUL(CUC) MUL(CSC)
+mul a b = error $ "mul: unsupported combination of values: "++show(a,b)
+
+neg :: Const -> Const
+#define NEG(X) neg (X a) = X (- a);
+NEG(I) NEG(I8) NEG(I16) NEG(I32) NEG(I64)
+NEG(W) NEG(W8) NEG(W16) NEG(W32) NEG(W64)
+NEG(F) NEG(D) NEG(CF) NEG(CD)
+NEG(CS) NEG(CI) NEG(CL) NEG(CLL)
+NEG(CUS) NEG(CUI) NEG(CUL) NEG(CULL)
+NEG(CC) NEG(CUC) NEG(CSC)
+neg a = error $ "negate: unsupported value: "++show a
+
+absv :: Const -> Const
+#define ABS(X) absv (X a) = X (Prelude.abs a);
+ABS(I) ABS(I8) ABS(I16) ABS(I32) ABS(I64)
+ABS(W) ABS(W8) ABS(W16) ABS(W32) ABS(W64)
+ABS(F) ABS(D) ABS(CF) ABS(CD)
+ABS(CS) ABS(CI) ABS(CL) ABS(CLL)
+ABS(CUS) ABS(CUI) ABS(CUL) ABS(CULL)
+ABS(CC) ABS(CUC) ABS(CSC)
+absv a = error $ "abs: unsupported value: "++show a
+
+sig :: Const -> Const
+#define SIG(X) sig (X a) = X (signum a);
+SIG(I) SIG(I8) SIG(I16) SIG(I32) SIG(I64)
+SIG(W) SIG(W8) SIG(W16) SIG(W32) SIG(W64)
+SIG(F) SIG(D) SIG(CF) SIG(CD)
+SIG(CS) SIG(CI) SIG(CL) SIG(CLL)
+SIG(CUS) SIG(CUI) SIG(CUL) SIG(CULL)
+SIG(CC) SIG(CUC) SIG(CSC)
+sig a = error $ "sig: unsupported value: "++show a
+
+
+
+
+
--- -- | Primitives that operate on /all/ numeric types.
--- -- Neg/Abs/Sig are unary:
--- data NumPrim = Add | Mul | Neg | Abs | Sig
--- deriving (Read,Show,Eq,Generic)
+
-- -- | Primitive integral-only operations.
-- -- All binops except BNot, shifts and rotates take an Int constant as second arg:
View
22 Test.hs
@@ -48,8 +48,8 @@ p1b :: Acc (Vector Float)
p1b = let xs = use$ fromList (Z :. (2::Int) :. (5::Int)) [1..10::Float]
in fold (+) 0 xs
t1b :: S.AExp
-t1b = convertToSimpleAST p1
-r1b = I.run p1
+t1b = convertToSimpleAST p1b
+r1b = I.run p1b
----------------------------------------
@@ -131,9 +131,25 @@ r7 = I.run p7
-- EIndexHeadDynamic (EVr v1)]))
-- (Vr a0))
-
-- TODO -- still need to generate an IndexCons node.
+----------------------------------------
+
+-- This shows an odd difference in staging:
+p8 :: Acc (Scalar Float)
+p8 = unit$ pi + (constant pi :: Exp Float) *
+-- pi -- (signum pi)
+ 33
+-- negate (abs (signum 33))
+
+t8 = convertToSimpleAST p8
+r8 = I.run p8
+
+-- Prim arguments don't need to directly be tuple expressions:
+-- unit ((+) (let x0 = pi in (x0, 3.1415927 * x0)))
+
+
+
--------------------------------------------------------------------------------
padleft n str | length str >= n = str

No commit comments for this range

Something went wrong with that request. Please try again.